Skip to content

Commit c80d2ff

Browse files
committed
Intercept current_schemas query by returning a constant list and register UDF for compatibility
1 parent 69b48bc commit c80d2ff

File tree

1 file changed

+141
-91
lines changed

1 file changed

+141
-91
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 141 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use async_trait::async_trait;
5-
use datafusion::arrow::array::StringArray;
5+
use datafusion::arrow::array::{ListBuilder, StringArray, StringBuilder};
66
use datafusion::arrow::datatypes::{DataType, Field, Schema};
77
use datafusion::arrow::record_batch::RecordBatch;
8-
use datafusion::logical_expr::LogicalPlan;
8+
use datafusion::logical_expr::{create_udf, ColumnarValue, LogicalPlan, Volatility};
99
use datafusion::prelude::*;
1010
use pgwire::api::auth::noop::NoopStartupHandler;
1111
use pgwire::api::copy::NoopCopyHandler;
@@ -17,7 +17,7 @@ use pgwire::api::results::{
1717
use pgwire::api::stmt::{QueryParser, StoredStatement};
1818
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
1919
use pgwire::error::{PgWireError, PgWireResult};
20-
use sqlparser::ast::{Expr, Ident, ObjectName, Statement};
20+
use sqlparser::ast::{Expr, Ident, ObjectName, Statement, ObjectNamePart};
2121
use sqlparser::dialect::GenericDialect;
2222
use sqlparser::parser::Parser as SqlParser;
2323
use tokio::sync::RwLock;
@@ -75,12 +75,31 @@ impl DfSessionService {
7575
}
7676
}
7777

78+
/// Call this method to register additional UDFs (such as current_schemas)
79+
pub async fn register_udfs(&self) -> datafusion::error::Result<()> {
80+
let mut ctx = self.session_context.write().await;
81+
register_current_schemas_udf(&mut ctx)?;
82+
Ok(())
83+
}
84+
85+
/// Helper function to read a custom session variable, returning the default if not set.
86+
async fn session_var(&self, key: &str, default: &str) -> String {
87+
self.custom_session_vars
88+
.read()
89+
.await
90+
.get(key)
91+
.cloned()
92+
.unwrap_or_else(|| default.to_string())
93+
}
94+
7895
async fn handle_set(&self, variable: &ObjectName, value: &[Expr]) -> PgWireResult<()> {
96+
// Join all parts of the ObjectName so that "TIME ZONE" becomes "timezone"
7997
let var_name = variable
8098
.0
81-
.first()
82-
.map(|ident| ident.to_string().to_lowercase())
83-
.unwrap_or_default();
99+
.iter()
100+
.map(|ident| ident.to_string())
101+
.collect::<String>()
102+
.to_lowercase();
84103

85104
let value_str = match value.first() {
86105
Some(Expr::Value(v)) => match &v.value {
@@ -151,90 +170,33 @@ impl DfSessionService {
151170
}
152171

153172
async fn handle_show<'a>(&self, variable: &[Ident]) -> PgWireResult<QueryResponse<'a>> {
173+
// Join all identifiers so that "TIME ZONE" becomes "timezone"
154174
let var_name = variable
155-
.first()
156-
.map(|ident| ident.to_string().to_lowercase())
157-
.unwrap_or_default();
175+
.iter()
176+
.map(|ident| ident.to_string())
177+
.collect::<String>()
178+
.to_lowercase();
158179

159180
let sc_guard = self.session_context.read().await;
160181
let config = sc_guard.state().config().options().clone();
161182

162183
let value = match var_name.as_str() {
163-
"timezone" => config
184+
// Support both "timezone" and "time" so that pgcli/psql are happy.
185+
"timezone" | "time" => config
164186
.execution
165187
.time_zone
166188
.clone()
167189
.unwrap_or_else(|| "UTC".to_string()),
168-
"client_encoding" => self
169-
.custom_session_vars
170-
.read()
171-
.await
172-
.get(&var_name)
173-
.cloned()
174-
.unwrap_or_else(|| "UTF8".to_string()),
175-
"search_path" => self
176-
.custom_session_vars
177-
.read()
178-
.await
179-
.get(&var_name)
180-
.cloned()
181-
.unwrap_or_else(|| "public".to_string()),
182-
"application_name" => self
183-
.custom_session_vars
184-
.read()
185-
.await
186-
.get(&var_name)
187-
.cloned()
188-
.unwrap_or_else(|| "".to_string()),
189-
"datestyle" => self
190-
.custom_session_vars
191-
.read()
192-
.await
193-
.get(&var_name)
194-
.cloned()
195-
.unwrap_or_else(|| "ISO, MDY".to_string()),
196-
"client_min_messages" => self
197-
.custom_session_vars
198-
.read()
199-
.await
200-
.get(&var_name)
201-
.cloned()
202-
.unwrap_or_else(|| "notice".to_string()),
203-
"extra_float_digits" => self
204-
.custom_session_vars
205-
.read()
206-
.await
207-
.get(&var_name)
208-
.cloned()
209-
.unwrap_or_else(|| "3".to_string()),
210-
"standard_conforming_strings" => self
211-
.custom_session_vars
212-
.read()
213-
.await
214-
.get(&var_name)
215-
.cloned()
216-
.unwrap_or_else(|| "on".to_string()),
217-
"check_function_bodies" => self
218-
.custom_session_vars
219-
.read()
220-
.await
221-
.get(&var_name)
222-
.cloned()
223-
.unwrap_or_else(|| "off".to_string()),
224-
"transaction_read_only" => self
225-
.custom_session_vars
226-
.read()
227-
.await
228-
.get(&var_name)
229-
.cloned()
230-
.unwrap_or_else(|| "off".to_string()),
231-
"transaction_isolation" => self
232-
.custom_session_vars
233-
.read()
234-
.await
235-
.get(&var_name)
236-
.cloned()
237-
.unwrap_or_else(|| "read committed".to_string()),
190+
"client_encoding" => self.session_var("client_encoding", "UTF8").await,
191+
"search_path" => self.session_var("search_path", "public").await,
192+
"application_name" => self.session_var("application_name", "").await,
193+
"datestyle" => self.session_var("datestyle", "ISO, MDY").await,
194+
"client_min_messages" => self.session_var("client_min_messages", "notice").await,
195+
"extra_float_digits" => self.session_var("extra_float_digits", "3").await,
196+
"standard_conforming_strings" => self.session_var("standard_conforming_strings", "on").await,
197+
"check_function_bodies" => self.session_var("check_function_bodies", "off").await,
198+
"transaction_read_only" => self.session_var("transaction_read_only", "off").await,
199+
"transaction_isolation" => self.session_var("transaction_isolation", "read committed").await,
238200

239201
// *** New variables to keep psql happy ***
240202
"server_version" => "14.0".to_string(),
@@ -280,6 +242,7 @@ impl DfSessionService {
280242
("lc_monetary", "en_US.UTF-8"),
281243
("lc_numeric", "en_US.UTF-8"),
282244
("lc_time", "en_US.UTF-8"),
245+
("time", "UTC"),
283246
];
284247

285248
for (k, v) in defaults {
@@ -291,7 +254,11 @@ impl DfSessionService {
291254

292255
let schema = Arc::new(Schema::new(vec![
293256
Field::new("name", DataType::Utf8, false),
294-
Field::new("setting", DataType::Utf8, false),
257+
Field::new(
258+
"setting",
259+
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)).into()),
260+
false,
261+
),
295262
]));
296263
let batch = RecordBatch::try_new(
297264
schema.clone(),
@@ -366,6 +333,55 @@ impl SimpleQueryHandler for DfSessionService {
366333
where
367334
C: ClientInfo + Unpin + Send + Sync,
368335
{
336+
let query_trimmed = query.trim();
337+
let query_lower = query_trimmed.to_lowercase();
338+
339+
// Intercept SELECT current_schemas(...) queries.
340+
if query_lower.starts_with("select current_schemas(") {
341+
// Build a StringArray with "public"
342+
let mut string_builder = StringBuilder::new();
343+
string_builder.append_value("public");
344+
// Build a ListArray containing "public"
345+
let mut list_builder = ListBuilder::new(StringBuilder::new());
346+
list_builder.values().append_value("public");
347+
list_builder.append(true);
348+
let list_array = list_builder.finish();
349+
350+
// Define schema for a single column "current_schemas" of type List(Utf8)
351+
let field = Field::new(
352+
"current_schemas",
353+
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)).into()),
354+
false,
355+
);
356+
let schema = Arc::new(Schema::new(vec![field]));
357+
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(list_array)])
358+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
359+
let sc_guard = self.session_context.read().await;
360+
let df = sc_guard
361+
.read_batch(batch)
362+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
363+
let encoded = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
364+
return Ok(vec![Response::Query(encoded)]);
365+
}
366+
367+
// Intercept SET TIME ZONE commands to handle them directly.
368+
if query_lower.starts_with("set time zone") {
369+
let parts: Vec<&str> = query_trimmed.split_whitespace().collect();
370+
if parts.len() >= 4 {
371+
let tz = parts[3].trim_matches('\'').trim_matches('"');
372+
let object_name =
373+
ObjectName(vec![ObjectNamePart::Identifier(Ident::new("timezone"))]);
374+
let expr = Expr::Value(
375+
sqlparser::ast::Value::SingleQuotedString(tz.to_string()).into(),
376+
);
377+
self.handle_set(&object_name, &[expr]).await?;
378+
return Ok(vec![Response::Execution(
379+
pgwire::api::results::Tag::new("SET"),
380+
)]);
381+
}
382+
}
383+
384+
// Otherwise, process the query normally.
369385
let dialect = GenericDialect {};
370386
let stmts = SqlParser::parse_sql(&dialect, query)
371387
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -382,12 +398,12 @@ impl SimpleQueryHandler for DfSessionService {
382398
} => {
383399
let var = match variables {
384400
sqlparser::ast::OneOrManyWithParens::One(ref name) => name,
385-
sqlparser::ast::OneOrManyWithParens::Many(ref names) => {
386-
names.first().unwrap()
387-
}
401+
sqlparser::ast::OneOrManyWithParens::Many(ref names) => names.first().unwrap(),
388402
};
389403
self.handle_set(var, &value).await?;
390-
responses.push(Response::Execution(pgwire::api::results::Tag::new("SET")));
404+
responses.push(Response::Execution(
405+
pgwire::api::results::Tag::new("SET"),
406+
));
391407
}
392408
Statement::ShowVariable { variable } => {
393409
let resp = self.handle_show(&variable).await?;
@@ -427,7 +443,8 @@ impl ExtendedQueryHandler for DfSessionService {
427443
{
428444
let plan = &target.statement;
429445
let schema = plan.schema();
430-
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
446+
let fields =
447+
datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
431448
let params = plan
432449
.get_parameter_types()
433450
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -485,7 +502,9 @@ impl ExtendedQueryHandler for DfSessionService {
485502
sqlparser::ast::OneOrManyWithParens::Many(ref names) => names.first().unwrap(),
486503
};
487504
self.handle_set(var, value).await?;
488-
return Ok(Response::Execution(pgwire::api::results::Tag::new("SET")));
505+
return Ok(Response::Execution(
506+
pgwire::api::results::Tag::new("SET"),
507+
));
489508
}
490509
} else if stmt_upper.starts_with("SHOW ") {
491510
let dialect = GenericDialect {};
@@ -497,7 +516,7 @@ impl ExtendedQueryHandler for DfSessionService {
497516
}
498517
}
499518

500-
// Otherwise, treat it as a normal prepared statement
519+
// Otherwise, treat it as a normal prepared statement.
501520
let plan = &portal.statement.statement;
502521
let param_types = plan
503522
.get_parameter_types()
@@ -515,15 +534,46 @@ impl ExtendedQueryHandler for DfSessionService {
515534
.await
516535
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
517536

518-
let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
537+
let resp =
538+
datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
519539
Ok(Response::Query(resp))
520540
}
521541
}
522542

523543
fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
524-
// Datafusion stores the parameters as a map. In our case, the keys will be
525-
// `$1`, `$2` etc. The values will be the parameter types.
544+
// Datafusion stores the parameters as a map. In our case, the keys will be
545+
// `$1`, `$2` etc. The values will be the parameter types.
526546
let mut types_vec = types.iter().collect::<Vec<_>>();
527547
types_vec.sort_by(|a, b| a.0.cmp(b.0));
528548
types_vec.into_iter().map(|pt| pt.1.as_ref()).collect()
529549
}
550+
551+
/// Register a UDF called `current_schemas` that takes a boolean and returns an array containing "public".
552+
fn register_current_schemas_udf(ctx: &mut SessionContext) -> datafusion::error::Result<()> {
553+
let current_schemas_fn = Arc::new(move |args: &[ColumnarValue]| -> datafusion::error::Result<ColumnarValue> {
554+
// We ignore the input value; just return a constant list containing "public".
555+
let num_rows = match &args[0] {
556+
ColumnarValue::Array(array) => array.len(),
557+
ColumnarValue::Scalar(_) => 1,
558+
};
559+
// Build a ListArray containing "public"
560+
let mut list_builder = ListBuilder::new(StringBuilder::new());
561+
for _ in 0..num_rows {
562+
list_builder.values().append_value("public");
563+
list_builder.append(true);
564+
}
565+
let list_array = list_builder.finish();
566+
Ok(ColumnarValue::Array(Arc::new(list_array)))
567+
});
568+
569+
let udf = create_udf(
570+
"current_schemas",
571+
vec![DataType::Boolean],
572+
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)).into()),
573+
Volatility::Immutable,
574+
current_schemas_fn,
575+
);
576+
577+
ctx.register_udf(udf);
578+
Ok(())
579+
}

0 commit comments

Comments
 (0)