Skip to content

Commit c566109

Browse files
committed
fix(pgwire): Add missing SHOW parameters for psql compatibility
1 parent 522f7b8 commit c566109

File tree

1 file changed

+59
-48
lines changed

1 file changed

+59
-48
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ impl DfSessionService {
124124
*sc_guard = new_context;
125125
Ok(())
126126
}
127-
"client_encoding"
127+
| "client_encoding"
128128
| "search_path"
129129
| "application_name"
130130
| "datestyle"
@@ -154,7 +154,6 @@ impl DfSessionService {
154154

155155
let sc_guard = self.session_context.read().await;
156156
let config = sc_guard.state().config().options().clone();
157-
drop(sc_guard);
158157

159158
let value = match var_name.as_str() {
160159
"timezone" => config
@@ -232,6 +231,17 @@ impl DfSessionService {
232231
.get(&var_name)
233232
.cloned()
234233
.unwrap_or_else(|| "read committed".to_string()),
234+
235+
// *** New variables to keep psql happy ***
236+
"server_version" => "14.0".to_string(),
237+
"server_version_num" => "140000".to_string(),
238+
"server_encoding" => "UTF8".to_string(),
239+
"is_superuser" => "off".to_string(),
240+
"lc_messages" => "en_US.UTF-8".to_string(),
241+
"lc_monetary" => "en_US.UTF-8".to_string(),
242+
"lc_numeric" => "en_US.UTF-8".to_string(),
243+
"lc_time" => "en_US.UTF-8".to_string(),
244+
235245
"all" => {
236246
let mut names = Vec::new();
237247
let mut values = Vec::new();
@@ -240,50 +250,39 @@ impl DfSessionService {
240250
names.push("timezone".to_string());
241251
values.push(tz.clone());
242252
}
253+
243254
let custom_vars = self.custom_session_vars.read().await;
244255
for (name, value) in custom_vars.iter() {
245256
names.push(name.clone());
246257
values.push(value.clone());
247258
}
248-
if !custom_vars.contains_key("client_encoding") {
249-
names.push("client_encoding".to_string());
250-
values.push("UTF8".to_string());
251-
}
252-
if !custom_vars.contains_key("search_path") {
253-
names.push("search_path".to_string());
254-
values.push("public".to_string());
255-
}
256-
if !custom_vars.contains_key("application_name") {
257-
names.push("application_name".to_string());
258-
values.push("".to_string());
259-
}
260-
if !custom_vars.contains_key("datestyle") {
261-
names.push("datestyle".to_string());
262-
values.push("ISO, MDY".to_string());
263-
}
264-
if !custom_vars.contains_key("client_min_messages") {
265-
names.push("client_min_messages".to_string());
266-
values.push("notice".to_string());
267-
}
268-
if !custom_vars.contains_key("extra_float_digits") {
269-
names.push("extra_float_digits".to_string());
270-
values.push("3".to_string());
271-
}
272-
if !custom_vars.contains_key("standard_conforming_strings") {
273-
names.push("standard_conforming_strings".to_string());
274-
values.push("on".to_string());
275-
}
276-
if !custom_vars.contains_key("check_function_bodies") {
277-
names.push("check_function_bodies".to_string());
278-
values.push("off".to_string());
279-
}
280-
if !custom_vars.contains_key("transaction_read_only") {
281-
names.push("transaction_read_only".to_string());
282-
values.push("off".to_string());
283-
}
284-
if !custom_vars.contains_key("transaction_isolation") {
285-
names.push("transaction_isolation".to_string());
286-
values.push("read committed".to_string());
259+
260+
let defaults = vec![
261+
("client_encoding", "UTF8"),
262+
("search_path", "public"),
263+
("application_name", ""),
264+
("datestyle", "ISO, MDY"),
265+
("client_min_messages", "notice"),
266+
("extra_float_digits", "3"),
267+
("standard_conforming_strings", "on"),
268+
("check_function_bodies", "off"),
269+
("transaction_read_only", "off"),
270+
("transaction_isolation", "read committed"),
271+
("server_version", "14.0"),
272+
("server_version_num", "140000"),
273+
("server_encoding", "UTF8"),
274+
("is_superuser", "off"),
275+
("lc_messages", "en_US.UTF-8"),
276+
("lc_monetary", "en_US.UTF-8"),
277+
("lc_numeric", "en_US.UTF-8"),
278+
("lc_time", "en_US.UTF-8"),
279+
];
280+
281+
for (k, v) in defaults {
282+
if !names.contains(&k.to_string()) {
283+
names.push(k.to_string());
284+
values.push(v.to_string());
285+
}
287286
}
288287

289288
let schema = Arc::new(Schema::new(vec![
@@ -298,13 +297,13 @@ impl DfSessionService {
298297
],
299298
)
300299
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
301-
let sc_guard = self.session_context.read().await;
300+
302301
let df = sc_guard
303302
.read_batch(batch)
304303
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
305-
drop(sc_guard);
306304
return datatypes::encode_dataframe(df, &Format::UnifiedText).await;
307305
}
306+
308307
_ => {
309308
return Err(PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
310309
"ERROR".to_string(),
@@ -315,13 +314,12 @@ impl DfSessionService {
315314
};
316315

317316
let schema = Arc::new(Schema::new(vec![Field::new(&var_name, DataType::Utf8, false)]));
318-
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec![value]))])
317+
let batch = RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(vec![value]))])
319318
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
320-
let sc_guard = self.session_context.read().await;
321319
let df = sc_guard
322320
.read_batch(batch)
323321
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
324-
drop(sc_guard);
322+
325323
datatypes::encode_dataframe(df, &Format::UnifiedText).await
326324
}
327325
}
@@ -333,6 +331,7 @@ pub struct Parser {
333331
#[async_trait]
334332
impl QueryParser for Parser {
335333
type Statement = LogicalPlan;
334+
336335
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
337336
let sc_guard = self.session_context.read().await;
338337
let state = sc_guard.state();
@@ -361,6 +360,7 @@ impl SimpleQueryHandler for DfSessionService {
361360
let stmts = SqlParser::parse_sql(&dialect, query)
362361
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
363362
let mut responses = Vec::with_capacity(stmts.len());
363+
364364
for statement in stmts {
365365
let stmt_string = statement.to_string().trim().to_owned();
366366
if stmt_string.is_empty() {
@@ -387,7 +387,6 @@ impl SimpleQueryHandler for DfSessionService {
387387
.sql(&stmt_string)
388388
.await
389389
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
390-
drop(sc_guard);
391390
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
392391
responses.push(Response::Query(resp));
393392
}
@@ -401,9 +400,11 @@ impl SimpleQueryHandler for DfSessionService {
401400
impl ExtendedQueryHandler for DfSessionService {
402401
type Statement = LogicalPlan;
403402
type QueryParser = Parser;
403+
404404
fn query_parser(&self) -> Arc<Self::QueryParser> {
405405
self.parser.clone()
406406
}
407+
407408
async fn do_describe_statement<C>(
408409
&self,
409410
_client: &mut C,
@@ -420,6 +421,7 @@ impl ExtendedQueryHandler for DfSessionService {
420421
.get_parameter_types()
421422
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
422423
let mut param_types = Vec::with_capacity(params.len());
424+
423425
for param_type in ordered_param_types(&params).iter() {
424426
if let Some(datatype) = param_type {
425427
let pgtype = into_pg_type(datatype)?;
@@ -430,6 +432,7 @@ impl ExtendedQueryHandler for DfSessionService {
430432
}
431433
Ok(DescribeStatementResponse::new(param_types, fields))
432434
}
435+
433436
async fn do_describe_portal<C>(
434437
&self,
435438
_client: &mut C,
@@ -444,6 +447,7 @@ impl ExtendedQueryHandler for DfSessionService {
444447
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
445448
Ok(DescribePortalResponse::new(fields))
446449
}
450+
447451
async fn do_query<'a, C>(
448452
&self,
449453
_client: &mut C,
@@ -455,6 +459,8 @@ impl ExtendedQueryHandler for DfSessionService {
455459
{
456460
let stmt_string = portal.statement.id.clone();
457461
let stmt_upper = stmt_string.to_uppercase();
462+
463+
// If the statement is a SET or SHOW, handle it here
458464
if stmt_upper.starts_with("SET ") {
459465
let dialect = GenericDialect {};
460466
let stmts = SqlParser::parse_sql(&dialect, &stmt_string)
@@ -476,6 +482,8 @@ impl ExtendedQueryHandler for DfSessionService {
476482
return Ok(Response::Query(resp));
477483
}
478484
}
485+
486+
// Otherwise, treat it as a normal prepared statement
479487
let plan = &portal.statement.statement;
480488
let param_types = plan
481489
.get_parameter_types()
@@ -486,12 +494,13 @@ impl ExtendedQueryHandler for DfSessionService {
486494
.clone()
487495
.replace_params_with_values(&param_values)
488496
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
497+
489498
let sc_guard = self.session_context.read().await;
490499
let dataframe = sc_guard
491500
.execute_logical_plan(plan)
492501
.await
493502
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
494-
drop(sc_guard);
503+
495504
let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
496505
Ok(Response::Query(resp))
497506
}
@@ -500,6 +509,8 @@ impl ExtendedQueryHandler for DfSessionService {
500509
fn ordered_param_types(
501510
types: &HashMap<String, Option<DataType>>,
502511
) -> Vec<Option<&DataType>> {
512+
// Datafusion stores the parameters as a map. In our case, the keys will be
513+
// `$1`, `$2` etc. The values will be the parameter types.
503514
let mut types_vec = types.iter().collect::<Vec<_>>();
504515
types_vec.sort_by(|a, b| a.0.cmp(b.0));
505516
types_vec.into_iter().map(|pt| pt.1.as_ref()).collect()

0 commit comments

Comments
 (0)