Skip to content

Commit 522f7b8

Browse files
committed
Improve PGWire compatibility by handling extra client parameters
1 parent 4eaabdb commit 522f7b8

File tree

1 file changed

+81
-45
lines changed

1 file changed

+81
-45
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 81 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
// src/handlers.rs
2-
31
use std::collections::HashMap;
42
use std::sync::Arc;
53

@@ -58,7 +56,6 @@ impl PgWireServerHandlers for HandlerFactory {
5856
}
5957
}
6058

61-
6259
pub struct DfSessionService {
6360
pub session_context: Arc<RwLock<SessionContext>>,
6461
pub parser: Arc<Parser>,
@@ -96,7 +93,7 @@ impl DfSessionService {
9693
None => {
9794
return Err(PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
9895
"ERROR".to_string(),
99-
"22023".to_string(),
96+
"22023".to_string(),
10097
"SET requires a value".to_string(),
10198
))));
10299
}
@@ -105,12 +102,9 @@ impl DfSessionService {
105102
match var_name.as_str() {
106103
"timezone" => {
107104
let mut sc_guard = self.session_context.write().await;
108-
109105
let mut config = sc_guard.state().config().options().clone();
110106
config.execution.time_zone = Some(value_str);
111-
112107
let new_context = SessionContext::new_with_config(config.into());
113-
114108
let old_catalog_names = sc_guard.catalog_names();
115109
for catalog_name in old_catalog_names {
116110
if let Some(catalog) = sc_guard.catalog(&catalog_name) {
@@ -127,11 +121,19 @@ impl DfSessionService {
127121
}
128122
}
129123
}
130-
131124
*sc_guard = new_context;
132125
Ok(())
133126
}
134-
"client_encoding" | "search_path" | "application_name" | "datestyle" => {
127+
"client_encoding"
128+
| "search_path"
129+
| "application_name"
130+
| "datestyle"
131+
| "client_min_messages"
132+
| "extra_float_digits"
133+
| "standard_conforming_strings"
134+
| "check_function_bodies"
135+
| "transaction_read_only"
136+
| "transaction_isolation" => {
135137
let mut vars = self.custom_session_vars.write().await;
136138
vars.insert(var_name, value_str);
137139
Ok(())
@@ -152,46 +154,84 @@ impl DfSessionService {
152154

153155
let sc_guard = self.session_context.read().await;
154156
let config = sc_guard.state().config().options().clone();
155-
drop(sc_guard);
157+
drop(sc_guard);
158+
156159
let value = match var_name.as_str() {
157160
"timezone" => config
158161
.execution
159162
.time_zone
160163
.clone()
161164
.unwrap_or_else(|| "UTC".to_string()),
162-
163165
"client_encoding" => self
164166
.custom_session_vars
165167
.read()
166168
.await
167169
.get(&var_name)
168170
.cloned()
169171
.unwrap_or_else(|| "UTF8".to_string()),
170-
171172
"search_path" => self
172173
.custom_session_vars
173174
.read()
174175
.await
175176
.get(&var_name)
176177
.cloned()
177178
.unwrap_or_else(|| "public".to_string()),
178-
179179
"application_name" => self
180180
.custom_session_vars
181181
.read()
182182
.await
183183
.get(&var_name)
184184
.cloned()
185185
.unwrap_or_else(|| "".to_string()),
186-
187186
"datestyle" => self
188187
.custom_session_vars
189188
.read()
190189
.await
191190
.get(&var_name)
192191
.cloned()
193192
.unwrap_or_else(|| "ISO, MDY".to_string()),
194-
193+
"client_min_messages" => self
194+
.custom_session_vars
195+
.read()
196+
.await
197+
.get(&var_name)
198+
.cloned()
199+
.unwrap_or_else(|| "notice".to_string()),
200+
"extra_float_digits" => self
201+
.custom_session_vars
202+
.read()
203+
.await
204+
.get(&var_name)
205+
.cloned()
206+
.unwrap_or_else(|| "3".to_string()),
207+
"standard_conforming_strings" => self
208+
.custom_session_vars
209+
.read()
210+
.await
211+
.get(&var_name)
212+
.cloned()
213+
.unwrap_or_else(|| "on".to_string()),
214+
"check_function_bodies" => self
215+
.custom_session_vars
216+
.read()
217+
.await
218+
.get(&var_name)
219+
.cloned()
220+
.unwrap_or_else(|| "off".to_string()),
221+
"transaction_read_only" => self
222+
.custom_session_vars
223+
.read()
224+
.await
225+
.get(&var_name)
226+
.cloned()
227+
.unwrap_or_else(|| "off".to_string()),
228+
"transaction_isolation" => self
229+
.custom_session_vars
230+
.read()
231+
.await
232+
.get(&var_name)
233+
.cloned()
234+
.unwrap_or_else(|| "read committed".to_string()),
195235
"all" => {
196236
let mut names = Vec::new();
197237
let mut values = Vec::new();
@@ -200,14 +240,11 @@ impl DfSessionService {
200240
names.push("timezone".to_string());
201241
values.push(tz.clone());
202242
}
203-
204243
let custom_vars = self.custom_session_vars.read().await;
205244
for (name, value) in custom_vars.iter() {
206245
names.push(name.clone());
207246
values.push(value.clone());
208247
}
209-
210-
// Provide defaults if not set
211248
if !custom_vars.contains_key("client_encoding") {
212249
names.push("client_encoding".to_string());
213250
values.push("UTF8".to_string());
@@ -224,12 +261,35 @@ impl DfSessionService {
224261
names.push("datestyle".to_string());
225262
values.push("ISO, MDY".to_string());
226263
}
264+
if !custom_vars.contains_key("client_min_messages") {
265+
names.push("client_min_messages".to_string());
266+
values.push("notice".to_string());
267+
}
268+
if !custom_vars.contains_key("extra_float_digits") {
269+
names.push("extra_float_digits".to_string());
270+
values.push("3".to_string());
271+
}
272+
if !custom_vars.contains_key("standard_conforming_strings") {
273+
names.push("standard_conforming_strings".to_string());
274+
values.push("on".to_string());
275+
}
276+
if !custom_vars.contains_key("check_function_bodies") {
277+
names.push("check_function_bodies".to_string());
278+
values.push("off".to_string());
279+
}
280+
if !custom_vars.contains_key("transaction_read_only") {
281+
names.push("transaction_read_only".to_string());
282+
values.push("off".to_string());
283+
}
284+
if !custom_vars.contains_key("transaction_isolation") {
285+
names.push("transaction_isolation".to_string());
286+
values.push("read committed".to_string());
287+
}
227288

228289
let schema = Arc::new(Schema::new(vec![
229290
Field::new("name", DataType::Utf8, false),
230291
Field::new("setting", DataType::Utf8, false),
231292
]));
232-
233293
let batch = RecordBatch::try_new(
234294
schema.clone(),
235295
vec![
@@ -238,13 +298,11 @@ impl DfSessionService {
238298
],
239299
)
240300
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
241-
242301
let sc_guard = self.session_context.read().await;
243302
let df = sc_guard
244303
.read_batch(batch)
245304
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
246305
drop(sc_guard);
247-
248306
return datatypes::encode_dataframe(df, &Format::UnifiedText).await;
249307
}
250308
_ => {
@@ -259,37 +317,32 @@ impl DfSessionService {
259317
let schema = Arc::new(Schema::new(vec![Field::new(&var_name, DataType::Utf8, false)]));
260318
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec![value]))])
261319
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
262-
263320
let sc_guard = self.session_context.read().await;
264321
let df = sc_guard
265322
.read_batch(batch)
266323
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
267324
drop(sc_guard);
268-
269325
datatypes::encode_dataframe(df, &Format::UnifiedText).await
270326
}
271327
}
272328

273329
pub struct Parser {
274-
session_context: Arc<RwLock<SessionContext>>,
330+
pub session_context: Arc<RwLock<SessionContext>>,
275331
}
276332

277333
#[async_trait]
278334
impl QueryParser for Parser {
279335
type Statement = LogicalPlan;
280-
281336
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
282337
let sc_guard = self.session_context.read().await;
283338
let state = sc_guard.state();
284-
285339
let logical_plan = state
286340
.create_logical_plan(sql)
287341
.await
288342
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
289343
let optimized = state
290344
.optimize(&logical_plan)
291345
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
292-
293346
Ok(optimized)
294347
}
295348
}
@@ -307,14 +360,12 @@ impl SimpleQueryHandler for DfSessionService {
307360
let dialect = GenericDialect {};
308361
let stmts = SqlParser::parse_sql(&dialect, query)
309362
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
310-
311363
let mut responses = Vec::with_capacity(stmts.len());
312364
for statement in stmts {
313365
let stmt_string = statement.to_string().trim().to_owned();
314366
if stmt_string.is_empty() {
315367
continue;
316368
}
317-
318369
match statement {
319370
Statement::SetVariable { variables, value, .. } => {
320371
let var = match variables {
@@ -337,13 +388,11 @@ impl SimpleQueryHandler for DfSessionService {
337388
.await
338389
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
339390
drop(sc_guard);
340-
341391
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
342392
responses.push(Response::Query(resp));
343393
}
344394
}
345395
}
346-
347396
Ok(responses)
348397
}
349398
}
@@ -352,11 +401,9 @@ impl SimpleQueryHandler for DfSessionService {
352401
impl ExtendedQueryHandler for DfSessionService {
353402
type Statement = LogicalPlan;
354403
type QueryParser = Parser;
355-
356404
fn query_parser(&self) -> Arc<Self::QueryParser> {
357405
self.parser.clone()
358406
}
359-
360407
async fn do_describe_statement<C>(
361408
&self,
362409
_client: &mut C,
@@ -366,14 +413,12 @@ impl ExtendedQueryHandler for DfSessionService {
366413
C: ClientInfo + Unpin + Send + Sync,
367414
{
368415
let plan = &target.statement;
369-
370416
let schema = plan.schema();
371417
let fields =
372418
datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
373419
let params = plan
374420
.get_parameter_types()
375421
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
376-
377422
let mut param_types = Vec::with_capacity(params.len());
378423
for param_type in ordered_param_types(&params).iter() {
379424
if let Some(datatype) = param_type {
@@ -383,10 +428,8 @@ impl ExtendedQueryHandler for DfSessionService {
383428
param_types.push(Type::UNKNOWN);
384429
}
385430
}
386-
387431
Ok(DescribeStatementResponse::new(param_types, fields))
388432
}
389-
390433
async fn do_describe_portal<C>(
391434
&self,
392435
_client: &mut C,
@@ -399,10 +442,8 @@ impl ExtendedQueryHandler for DfSessionService {
399442
let format = &target.result_column_format;
400443
let schema = plan.schema();
401444
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
402-
403445
Ok(DescribePortalResponse::new(fields))
404446
}
405-
406447
async fn do_query<'a, C>(
407448
&self,
408449
_client: &mut C,
@@ -412,9 +453,8 @@ impl ExtendedQueryHandler for DfSessionService {
412453
where
413454
C: ClientInfo + Unpin + Send + Sync,
414455
{
415-
let stmt_string = portal.statement.id.clone();
456+
let stmt_string = portal.statement.id.clone();
416457
let stmt_upper = stmt_string.to_uppercase();
417-
418458
if stmt_upper.starts_with("SET ") {
419459
let dialect = GenericDialect {};
420460
let stmts = SqlParser::parse_sql(&dialect, &stmt_string)
@@ -436,7 +476,6 @@ impl ExtendedQueryHandler for DfSessionService {
436476
return Ok(Response::Query(resp));
437477
}
438478
}
439-
440479
let plan = &portal.statement.statement;
441480
let param_types = plan
442481
.get_parameter_types()
@@ -447,20 +486,17 @@ impl ExtendedQueryHandler for DfSessionService {
447486
.clone()
448487
.replace_params_with_values(&param_values)
449488
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
450-
451489
let sc_guard = self.session_context.read().await;
452490
let dataframe = sc_guard
453491
.execute_logical_plan(plan)
454492
.await
455493
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
456494
drop(sc_guard);
457-
458495
let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
459496
Ok(Response::Query(resp))
460497
}
461498
}
462499

463-
/// Helper to convert DataFusion’s parameter map into an ordered list.
464500
fn ordered_param_types(
465501
types: &HashMap<String, Option<DataType>>,
466502
) -> Vec<Option<&DataType>> {

0 commit comments

Comments
 (0)