11// src/handlers.rs
2+
23use std:: collections:: HashMap ;
34use std:: sync:: Arc ;
45
@@ -21,6 +22,7 @@ use pgwire::error::{PgWireError, PgWireResult};
2122use sqlparser:: ast:: { Expr , Ident , ObjectName , Statement } ;
2223use sqlparser:: dialect:: GenericDialect ;
2324use sqlparser:: parser:: Parser as SqlParser ;
25+ use tokio:: sync:: RwLock ;
2426
2527use crate :: datatypes:: { self , into_pg_type} ;
2628
@@ -56,37 +58,41 @@ impl PgWireServerHandlers for HandlerFactory {
5658 }
5759}
5860
61+
5962pub struct DfSessionService {
60- pub session_context : Arc < tokio :: sync :: RwLock < SessionContext > > ,
63+ pub session_context : Arc < RwLock < SessionContext > > ,
6164 pub parser : Arc < Parser > ,
62- custom_session_vars : Arc < tokio :: sync :: RwLock < HashMap < String , String > > > ,
65+ custom_session_vars : Arc < RwLock < HashMap < String , String > > > ,
6366}
6467
6568impl DfSessionService {
6669 pub fn new ( session_context : SessionContext ) -> DfSessionService {
67- let session_context = Arc :: new ( tokio :: sync :: RwLock :: new ( session_context) ) ;
70+ let session_context = Arc :: new ( RwLock :: new ( session_context) ) ;
6871 let parser = Arc :: new ( Parser {
6972 session_context : session_context. clone ( ) ,
7073 } ) ;
7174 DfSessionService {
7275 session_context,
7376 parser,
74- custom_session_vars : Arc :: new ( tokio :: sync :: RwLock :: new ( HashMap :: new ( ) ) ) ,
77+ custom_session_vars : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
7578 }
7679 }
7780
7881 async fn handle_set ( & self , variable : & ObjectName , value : & [ Expr ] ) -> PgWireResult < ( ) > {
79- let var_name = variable. 0 . get ( 0 )
82+ let var_name = variable
83+ . 0
84+ . get ( 0 )
8085 . map ( |ident| ident. to_string ( ) . to_lowercase ( ) )
8186 . unwrap_or_default ( ) ;
87+
8288 let value_str = match value. get ( 0 ) {
8389 Some ( Expr :: Value ( v) ) => match & v. value {
8490 sqlparser:: ast:: Value :: SingleQuotedString ( s)
8591 | sqlparser:: ast:: Value :: DoubleQuotedString ( s) => s. clone ( ) ,
8692 sqlparser:: ast:: Value :: Number ( n, _) => n. to_string ( ) ,
8793 _ => v. to_string ( ) ,
8894 } ,
89- Some ( expr) => expr. to_string ( ) ,
95+ Some ( expr) => expr. to_string ( ) ,
9096 None => {
9197 return Err ( PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
9298 "ERROR" . to_string ( ) ,
@@ -98,86 +104,110 @@ impl DfSessionService {
98104
99105 match var_name. as_str ( ) {
100106 "timezone" => {
101- let config = {
102- let ctx = self . session_context . read ( ) . await ;
103- ctx. state ( ) . config ( ) . options ( ) . clone ( )
104- } ;
105- let mut new_config = config;
106- new_config. execution . time_zone = Some ( value_str) ;
107- let new_context = SessionContext :: new_with_config ( new_config. into ( ) ) ;
108- {
109- let ctx = self . session_context . read ( ) . await ;
110- for catalog_name in ctx. catalog_names ( ) {
111- if let Some ( catalog) = ctx. catalog ( & catalog_name) {
112- for schema_name in catalog. schema_names ( ) {
113- if let Some ( schema) = catalog. schema ( & schema_name) {
114- for table_name in schema. table_names ( ) {
115- if let Ok ( Some ( table) ) = schema. table ( & table_name) . await {
116- new_context. register_table ( & table_name, table)
117- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
118- }
107+ let mut sc_guard = self . session_context . write ( ) . await ;
108+
109+ let mut config = sc_guard. state ( ) . config ( ) . options ( ) . clone ( ) ;
110+ config. execution . time_zone = Some ( value_str) ;
111+
112+ let new_context = SessionContext :: new_with_config ( config. into ( ) ) ;
113+
114+ let old_catalog_names = sc_guard. catalog_names ( ) ;
115+ for catalog_name in old_catalog_names {
116+ if let Some ( catalog) = sc_guard. catalog ( & catalog_name) {
117+ for schema_name in catalog. schema_names ( ) {
118+ if let Some ( schema) = catalog. schema ( & schema_name) {
119+ for table_name in schema. table_names ( ) {
120+ if let Ok ( Some ( table) ) = schema. table ( & table_name) . await {
121+ new_context
122+ . register_table ( & table_name, table)
123+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
119124 }
120125 }
121126 }
122127 }
123128 }
124129 }
125- {
126- let mut ctx = self . session_context . write ( ) . await ;
127- * ctx = new_context;
128- }
130+
131+ * sc_guard = new_context;
129132 Ok ( ( ) )
130133 }
131- "client_encoding" | "search_path" | "application_name" => {
134+ "client_encoding" | "search_path" | "application_name" | "datestyle" => {
132135 let mut vars = self . custom_session_vars . write ( ) . await ;
133136 vars. insert ( var_name, value_str) ;
134137 Ok ( ( ) )
135138 }
136139 _ => Err ( PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
137140 "ERROR" . to_string ( ) ,
138- "42704" . to_string ( ) , // Undefined object
141+ "42704" . to_string ( ) ,
139142 format ! ( "Unrecognized configuration parameter '{}'" , var_name) ,
140143 ) ) ) ) ,
141144 }
142145 }
143146
144147 async fn handle_show < ' a > ( & self , variable : & [ Ident ] ) -> PgWireResult < QueryResponse < ' a > > {
145- let var_name = variable. get ( 0 )
148+ let var_name = variable
149+ . get ( 0 )
146150 . map ( |ident| ident. to_string ( ) . to_lowercase ( ) )
147151 . unwrap_or_default ( ) ;
148- let config = {
149- let ctx = self . session_context . read ( ) . await ;
150- ctx . state ( ) . config ( ) . options ( ) . clone ( )
151- } ;
152+
153+ let sc_guard = self . session_context . read ( ) . await ;
154+ let config = sc_guard . state ( ) . config ( ) . options ( ) . clone ( ) ;
155+ drop ( sc_guard ) ;
152156 let value = match var_name. as_str ( ) {
153- "timezone" => config. execution . time_zone . clone ( ) . unwrap_or_else ( || "UTC" . to_string ( ) ) ,
154- "client_encoding" => self . custom_session_vars
155- . read ( ) . await
157+ "timezone" => config
158+ . execution
159+ . time_zone
160+ . clone ( )
161+ . unwrap_or_else ( || "UTC" . to_string ( ) ) ,
162+
163+ "client_encoding" => self
164+ . custom_session_vars
165+ . read ( )
166+ . await
156167 . get ( & var_name)
157168 . cloned ( )
158169 . unwrap_or_else ( || "UTF8" . to_string ( ) ) ,
159- "search_path" => self . custom_session_vars
160- . read ( ) . await
170+
171+ "search_path" => self
172+ . custom_session_vars
173+ . read ( )
174+ . await
161175 . get ( & var_name)
162176 . cloned ( )
163177 . unwrap_or_else ( || "public" . to_string ( ) ) ,
164- "application_name" => self . custom_session_vars
165- . read ( ) . await
178+
179+ "application_name" => self
180+ . custom_session_vars
181+ . read ( )
182+ . await
166183 . get ( & var_name)
167184 . cloned ( )
168185 . unwrap_or_else ( || "" . to_string ( ) ) ,
186+
187+ "datestyle" => self
188+ . custom_session_vars
189+ . read ( )
190+ . await
191+ . get ( & var_name)
192+ . cloned ( )
193+ . unwrap_or_else ( || "ISO, MDY" . to_string ( ) ) ,
194+
169195 "all" => {
170196 let mut names = Vec :: new ( ) ;
171197 let mut values = Vec :: new ( ) ;
198+
172199 if let Some ( tz) = & config. execution . time_zone {
173200 names. push ( "timezone" . to_string ( ) ) ;
174201 values. push ( tz. clone ( ) ) ;
175202 }
203+
176204 let custom_vars = self . custom_session_vars . read ( ) . await ;
177205 for ( name, value) in custom_vars. iter ( ) {
178206 names. push ( name. clone ( ) ) ;
179207 values. push ( value. clone ( ) ) ;
180208 }
209+
210+ // Provide defaults if not set
181211 if !custom_vars. contains_key ( "client_encoding" ) {
182212 names. push ( "client_encoding" . to_string ( ) ) ;
183213 values. push ( "UTF8" . to_string ( ) ) ;
@@ -190,10 +220,16 @@ impl DfSessionService {
190220 names. push ( "application_name" . to_string ( ) ) ;
191221 values. push ( "" . to_string ( ) ) ;
192222 }
223+ if !custom_vars. contains_key ( "datestyle" ) {
224+ names. push ( "datestyle" . to_string ( ) ) ;
225+ values. push ( "ISO, MDY" . to_string ( ) ) ;
226+ }
227+
193228 let schema = Arc :: new ( Schema :: new ( vec ! [
194229 Field :: new( "name" , DataType :: Utf8 , false ) ,
195230 Field :: new( "setting" , DataType :: Utf8 , false ) ,
196231 ] ) ) ;
232+
197233 let batch = RecordBatch :: try_new (
198234 schema. clone ( ) ,
199235 vec ! [
@@ -202,11 +238,13 @@ impl DfSessionService {
202238 ] ,
203239 )
204240 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
205- let df = {
206- let ctx = self . session_context . read ( ) . await ;
207- ctx. read_batch ( batch)
208- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
209- } ;
241+
242+ let sc_guard = self . session_context . read ( ) . await ;
243+ let df = sc_guard
244+ . read_batch ( batch)
245+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
246+ drop ( sc_guard) ;
247+
210248 return datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ;
211249 }
212250 _ => {
@@ -218,42 +256,40 @@ impl DfSessionService {
218256 }
219257 } ;
220258
221- let schema = Arc :: new ( Schema :: new ( vec ! [
222- Field :: new( & var_name, DataType :: Utf8 , false ) ,
223- ] ) ) ;
224- let batch = RecordBatch :: try_new (
225- schema. clone ( ) ,
226- vec ! [ Arc :: new( StringArray :: from( vec![ value] ) ) ] ,
227- )
228- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
229- let df = {
230- let ctx = self . session_context . read ( ) . await ;
231- ctx. read_batch ( batch)
232- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
233- } ;
259+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( & var_name, DataType :: Utf8 , false ) ] ) ) ;
260+ let batch = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( StringArray :: from( vec![ value] ) ) ] )
261+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
262+
263+ let sc_guard = self . session_context . read ( ) . await ;
264+ let df = sc_guard
265+ . read_batch ( batch)
266+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
267+ drop ( sc_guard) ;
268+
234269 datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await
235270 }
236271}
237272
238273pub struct Parser {
239- session_context : Arc < tokio :: sync :: RwLock < SessionContext > > ,
274+ session_context : Arc < RwLock < SessionContext > > ,
240275}
241276
242277#[ async_trait]
243278impl QueryParser for Parser {
244279 type Statement = LogicalPlan ;
245280
246281 async fn parse_sql ( & self , sql : & str , _types : & [ Type ] ) -> PgWireResult < Self :: Statement > {
247- let ctx = self . session_context . read ( ) . await ;
248- let logical_plan = ctx
249- . state ( )
282+ let sc_guard = self . session_context . read ( ) . await ;
283+ let state = sc_guard. state ( ) ;
284+
285+ let logical_plan = state
250286 . create_logical_plan ( sql)
251287 . await
252288 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
253- let optimized = ctx
254- . state ( )
289+ let optimized = state
255290 . optimize ( & logical_plan)
256291 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
292+
257293 Ok ( optimized)
258294 }
259295}
@@ -295,12 +331,13 @@ impl SimpleQueryHandler for DfSessionService {
295331 responses. push ( Response :: Query ( resp) ) ;
296332 }
297333 _ => {
298- let df = {
299- let ctx = self . session_context . read ( ) . await ;
300- ctx. sql ( & stmt_string)
301- . await
302- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
303- } ;
334+ let sc_guard = self . session_context . read ( ) . await ;
335+ let df = sc_guard
336+ . sql ( & stmt_string)
337+ . await
338+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
339+ drop ( sc_guard) ;
340+
304341 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
305342 responses. push ( Response :: Query ( resp) ) ;
306343 }
@@ -361,8 +398,8 @@ impl ExtendedQueryHandler for DfSessionService {
361398 let plan = & target. statement . statement ;
362399 let format = & target. result_column_format ;
363400 let schema = plan. schema ( ) ;
364- let fields =
365- datatypes :: df_schema_to_pg_fields ( schema . as_ref ( ) , format ) ? ;
401+ let fields = datatypes :: df_schema_to_pg_fields ( schema . as_ref ( ) , format ) ? ;
402+
366403 Ok ( DescribePortalResponse :: new ( fields) )
367404 }
368405
@@ -388,9 +425,7 @@ impl ExtendedQueryHandler for DfSessionService {
388425 sqlparser:: ast:: OneOrManyWithParens :: Many ( ref names) => names. first ( ) . unwrap ( ) ,
389426 } ;
390427 self . handle_set ( var, & value) . await ?;
391- return Ok ( Response :: Execution (
392- pgwire:: api:: results:: Tag :: new ( "SET" ) . into ( ) ,
393- ) ) ;
428+ return Ok ( Response :: Execution ( pgwire:: api:: results:: Tag :: new ( "SET" ) . into ( ) ) ) ;
394429 }
395430 } else if stmt_upper. starts_with ( "SHOW " ) {
396431 let dialect = GenericDialect { } ;
@@ -406,20 +441,20 @@ impl ExtendedQueryHandler for DfSessionService {
406441 let param_types = plan
407442 . get_parameter_types ( )
408443 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
409- let param_values = datatypes:: deserialize_parameters (
410- portal,
411- & ordered_param_types ( & param_types) ,
412- ) ?;
444+ let param_values =
445+ datatypes:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?;
413446 let plan = plan
414447 . clone ( )
415448 . replace_params_with_values ( & param_values)
416449 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
417- let dataframe = {
418- let ctx = self . session_context . read ( ) . await ;
419- ctx. execute_logical_plan ( plan)
420- . await
421- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
422- } ;
450+
451+ let sc_guard = self . session_context . read ( ) . await ;
452+ let dataframe = sc_guard
453+ . execute_logical_plan ( plan)
454+ . await
455+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
456+ drop ( sc_guard) ;
457+
423458 let resp = datatypes:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
424459 Ok ( Response :: Query ( resp) )
425460 }
0 commit comments