1- // src/handlers.rs
2-
31use std:: collections:: HashMap ;
42use std:: sync:: Arc ;
53
@@ -58,7 +56,6 @@ impl PgWireServerHandlers for HandlerFactory {
5856 }
5957}
6058
61-
6259pub struct DfSessionService {
6360 pub session_context : Arc < RwLock < SessionContext > > ,
6461 pub parser : Arc < Parser > ,
@@ -96,7 +93,7 @@ impl DfSessionService {
9693 None => {
9794 return Err ( PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
9895 "ERROR" . to_string ( ) ,
99- "22023" . to_string ( ) ,
96+ "22023" . to_string ( ) ,
10097 "SET requires a value" . to_string ( ) ,
10198 ) ) ) ) ;
10299 }
@@ -105,12 +102,9 @@ impl DfSessionService {
105102 match var_name. as_str ( ) {
106103 "timezone" => {
107104 let mut sc_guard = self . session_context . write ( ) . await ;
108-
109105 let mut config = sc_guard. state ( ) . config ( ) . options ( ) . clone ( ) ;
110106 config. execution . time_zone = Some ( value_str) ;
111-
112107 let new_context = SessionContext :: new_with_config ( config. into ( ) ) ;
113-
114108 let old_catalog_names = sc_guard. catalog_names ( ) ;
115109 for catalog_name in old_catalog_names {
116110 if let Some ( catalog) = sc_guard. catalog ( & catalog_name) {
@@ -127,11 +121,19 @@ impl DfSessionService {
127121 }
128122 }
129123 }
130-
131124 * sc_guard = new_context;
132125 Ok ( ( ) )
133126 }
134- "client_encoding" | "search_path" | "application_name" | "datestyle" => {
127+ "client_encoding"
128+ | "search_path"
129+ | "application_name"
130+ | "datestyle"
131+ | "client_min_messages"
132+ | "extra_float_digits"
133+ | "standard_conforming_strings"
134+ | "check_function_bodies"
135+ | "transaction_read_only"
136+ | "transaction_isolation" => {
135137 let mut vars = self . custom_session_vars . write ( ) . await ;
136138 vars. insert ( var_name, value_str) ;
137139 Ok ( ( ) )
@@ -152,46 +154,84 @@ impl DfSessionService {
152154
153155 let sc_guard = self . session_context . read ( ) . await ;
154156 let config = sc_guard. state ( ) . config ( ) . options ( ) . clone ( ) ;
155- drop ( sc_guard) ;
157+ drop ( sc_guard) ;
158+
156159 let value = match var_name. as_str ( ) {
157160 "timezone" => config
158161 . execution
159162 . time_zone
160163 . clone ( )
161164 . unwrap_or_else ( || "UTC" . to_string ( ) ) ,
162-
163165 "client_encoding" => self
164166 . custom_session_vars
165167 . read ( )
166168 . await
167169 . get ( & var_name)
168170 . cloned ( )
169171 . unwrap_or_else ( || "UTF8" . to_string ( ) ) ,
170-
171172 "search_path" => self
172173 . custom_session_vars
173174 . read ( )
174175 . await
175176 . get ( & var_name)
176177 . cloned ( )
177178 . unwrap_or_else ( || "public" . to_string ( ) ) ,
178-
179179 "application_name" => self
180180 . custom_session_vars
181181 . read ( )
182182 . await
183183 . get ( & var_name)
184184 . cloned ( )
185185 . unwrap_or_else ( || "" . to_string ( ) ) ,
186-
187186 "datestyle" => self
188187 . custom_session_vars
189188 . read ( )
190189 . await
191190 . get ( & var_name)
192191 . cloned ( )
193192 . unwrap_or_else ( || "ISO, MDY" . to_string ( ) ) ,
194-
193+ "client_min_messages" => self
194+ . custom_session_vars
195+ . read ( )
196+ . await
197+ . get ( & var_name)
198+ . cloned ( )
199+ . unwrap_or_else ( || "notice" . to_string ( ) ) ,
200+ "extra_float_digits" => self
201+ . custom_session_vars
202+ . read ( )
203+ . await
204+ . get ( & var_name)
205+ . cloned ( )
206+ . unwrap_or_else ( || "3" . to_string ( ) ) ,
207+ "standard_conforming_strings" => self
208+ . custom_session_vars
209+ . read ( )
210+ . await
211+ . get ( & var_name)
212+ . cloned ( )
213+ . unwrap_or_else ( || "on" . to_string ( ) ) ,
214+ "check_function_bodies" => self
215+ . custom_session_vars
216+ . read ( )
217+ . await
218+ . get ( & var_name)
219+ . cloned ( )
220+ . unwrap_or_else ( || "off" . to_string ( ) ) ,
221+ "transaction_read_only" => self
222+ . custom_session_vars
223+ . read ( )
224+ . await
225+ . get ( & var_name)
226+ . cloned ( )
227+ . unwrap_or_else ( || "off" . to_string ( ) ) ,
228+ "transaction_isolation" => self
229+ . custom_session_vars
230+ . read ( )
231+ . await
232+ . get ( & var_name)
233+ . cloned ( )
234+ . unwrap_or_else ( || "read committed" . to_string ( ) ) ,
195235 "all" => {
196236 let mut names = Vec :: new ( ) ;
197237 let mut values = Vec :: new ( ) ;
@@ -200,14 +240,11 @@ impl DfSessionService {
200240 names. push ( "timezone" . to_string ( ) ) ;
201241 values. push ( tz. clone ( ) ) ;
202242 }
203-
204243 let custom_vars = self . custom_session_vars . read ( ) . await ;
205244 for ( name, value) in custom_vars. iter ( ) {
206245 names. push ( name. clone ( ) ) ;
207246 values. push ( value. clone ( ) ) ;
208247 }
209-
210- // Provide defaults if not set
211248 if !custom_vars. contains_key ( "client_encoding" ) {
212249 names. push ( "client_encoding" . to_string ( ) ) ;
213250 values. push ( "UTF8" . to_string ( ) ) ;
@@ -224,12 +261,35 @@ impl DfSessionService {
224261 names. push ( "datestyle" . to_string ( ) ) ;
225262 values. push ( "ISO, MDY" . to_string ( ) ) ;
226263 }
264+ if !custom_vars. contains_key ( "client_min_messages" ) {
265+ names. push ( "client_min_messages" . to_string ( ) ) ;
266+ values. push ( "notice" . to_string ( ) ) ;
267+ }
268+ if !custom_vars. contains_key ( "extra_float_digits" ) {
269+ names. push ( "extra_float_digits" . to_string ( ) ) ;
270+ values. push ( "3" . to_string ( ) ) ;
271+ }
272+ if !custom_vars. contains_key ( "standard_conforming_strings" ) {
273+ names. push ( "standard_conforming_strings" . to_string ( ) ) ;
274+ values. push ( "on" . to_string ( ) ) ;
275+ }
276+ if !custom_vars. contains_key ( "check_function_bodies" ) {
277+ names. push ( "check_function_bodies" . to_string ( ) ) ;
278+ values. push ( "off" . to_string ( ) ) ;
279+ }
280+ if !custom_vars. contains_key ( "transaction_read_only" ) {
281+ names. push ( "transaction_read_only" . to_string ( ) ) ;
282+ values. push ( "off" . to_string ( ) ) ;
283+ }
284+ if !custom_vars. contains_key ( "transaction_isolation" ) {
285+ names. push ( "transaction_isolation" . to_string ( ) ) ;
286+ values. push ( "read committed" . to_string ( ) ) ;
287+ }
227288
228289 let schema = Arc :: new ( Schema :: new ( vec ! [
229290 Field :: new( "name" , DataType :: Utf8 , false ) ,
230291 Field :: new( "setting" , DataType :: Utf8 , false ) ,
231292 ] ) ) ;
232-
233293 let batch = RecordBatch :: try_new (
234294 schema. clone ( ) ,
235295 vec ! [
@@ -238,13 +298,11 @@ impl DfSessionService {
238298 ] ,
239299 )
240300 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
241-
242301 let sc_guard = self . session_context . read ( ) . await ;
243302 let df = sc_guard
244303 . read_batch ( batch)
245304 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
246305 drop ( sc_guard) ;
247-
248306 return datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ;
249307 }
250308 _ => {
@@ -259,37 +317,32 @@ impl DfSessionService {
259317 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( & var_name, DataType :: Utf8 , false ) ] ) ) ;
260318 let batch = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( StringArray :: from( vec![ value] ) ) ] )
261319 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
262-
263320 let sc_guard = self . session_context . read ( ) . await ;
264321 let df = sc_guard
265322 . read_batch ( batch)
266323 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
267324 drop ( sc_guard) ;
268-
269325 datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await
270326 }
271327}
272328
273329pub struct Parser {
274- session_context : Arc < RwLock < SessionContext > > ,
330+ pub session_context : Arc < RwLock < SessionContext > > ,
275331}
276332
277333#[ async_trait]
278334impl QueryParser for Parser {
279335 type Statement = LogicalPlan ;
280-
281336 async fn parse_sql ( & self , sql : & str , _types : & [ Type ] ) -> PgWireResult < Self :: Statement > {
282337 let sc_guard = self . session_context . read ( ) . await ;
283338 let state = sc_guard. state ( ) ;
284-
285339 let logical_plan = state
286340 . create_logical_plan ( sql)
287341 . await
288342 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
289343 let optimized = state
290344 . optimize ( & logical_plan)
291345 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
292-
293346 Ok ( optimized)
294347 }
295348}
@@ -307,14 +360,12 @@ impl SimpleQueryHandler for DfSessionService {
307360 let dialect = GenericDialect { } ;
308361 let stmts = SqlParser :: parse_sql ( & dialect, query)
309362 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
310-
311363 let mut responses = Vec :: with_capacity ( stmts. len ( ) ) ;
312364 for statement in stmts {
313365 let stmt_string = statement. to_string ( ) . trim ( ) . to_owned ( ) ;
314366 if stmt_string. is_empty ( ) {
315367 continue ;
316368 }
317-
318369 match statement {
319370 Statement :: SetVariable { variables, value, .. } => {
320371 let var = match variables {
@@ -337,13 +388,11 @@ impl SimpleQueryHandler for DfSessionService {
337388 . await
338389 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
339390 drop ( sc_guard) ;
340-
341391 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
342392 responses. push ( Response :: Query ( resp) ) ;
343393 }
344394 }
345395 }
346-
347396 Ok ( responses)
348397 }
349398}
@@ -352,11 +401,9 @@ impl SimpleQueryHandler for DfSessionService {
352401impl ExtendedQueryHandler for DfSessionService {
353402 type Statement = LogicalPlan ;
354403 type QueryParser = Parser ;
355-
356404 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
357405 self . parser . clone ( )
358406 }
359-
360407 async fn do_describe_statement < C > (
361408 & self ,
362409 _client : & mut C ,
@@ -366,14 +413,12 @@ impl ExtendedQueryHandler for DfSessionService {
366413 C : ClientInfo + Unpin + Send + Sync ,
367414 {
368415 let plan = & target. statement ;
369-
370416 let schema = plan. schema ( ) ;
371417 let fields =
372418 datatypes:: df_schema_to_pg_fields ( schema. as_ref ( ) , & Format :: UnifiedBinary ) ?;
373419 let params = plan
374420 . get_parameter_types ( )
375421 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
376-
377422 let mut param_types = Vec :: with_capacity ( params. len ( ) ) ;
378423 for param_type in ordered_param_types ( & params) . iter ( ) {
379424 if let Some ( datatype) = param_type {
@@ -383,10 +428,8 @@ impl ExtendedQueryHandler for DfSessionService {
383428 param_types. push ( Type :: UNKNOWN ) ;
384429 }
385430 }
386-
387431 Ok ( DescribeStatementResponse :: new ( param_types, fields) )
388432 }
389-
390433 async fn do_describe_portal < C > (
391434 & self ,
392435 _client : & mut C ,
@@ -399,10 +442,8 @@ impl ExtendedQueryHandler for DfSessionService {
399442 let format = & target. result_column_format ;
400443 let schema = plan. schema ( ) ;
401444 let fields = datatypes:: df_schema_to_pg_fields ( schema. as_ref ( ) , format) ?;
402-
403445 Ok ( DescribePortalResponse :: new ( fields) )
404446 }
405-
406447 async fn do_query < ' a , C > (
407448 & self ,
408449 _client : & mut C ,
@@ -412,9 +453,8 @@ impl ExtendedQueryHandler for DfSessionService {
412453 where
413454 C : ClientInfo + Unpin + Send + Sync ,
414455 {
415- let stmt_string = portal. statement . id . clone ( ) ;
456+ let stmt_string = portal. statement . id . clone ( ) ;
416457 let stmt_upper = stmt_string. to_uppercase ( ) ;
417-
418458 if stmt_upper. starts_with ( "SET " ) {
419459 let dialect = GenericDialect { } ;
420460 let stmts = SqlParser :: parse_sql ( & dialect, & stmt_string)
@@ -436,7 +476,6 @@ impl ExtendedQueryHandler for DfSessionService {
436476 return Ok ( Response :: Query ( resp) ) ;
437477 }
438478 }
439-
440479 let plan = & portal. statement . statement ;
441480 let param_types = plan
442481 . get_parameter_types ( )
@@ -447,20 +486,17 @@ impl ExtendedQueryHandler for DfSessionService {
447486 . clone ( )
448487 . replace_params_with_values ( & param_values)
449488 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
450-
451489 let sc_guard = self . session_context . read ( ) . await ;
452490 let dataframe = sc_guard
453491 . execute_logical_plan ( plan)
454492 . await
455493 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
456494 drop ( sc_guard) ;
457-
458495 let resp = datatypes:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
459496 Ok ( Response :: Query ( resp) )
460497 }
461498}
462499
463- /// Helper to convert DataFusion’s parameter map into an ordered list.
464500fn ordered_param_types (
465501 types : & HashMap < String , Option < DataType > > ,
466502) -> Vec < Option < & DataType > > {
0 commit comments