Skip to content

Commit 5324ad2

Browse files
committed
feat: add support for postgres variables of formatting
1 parent c0a0834 commit 5324ad2

File tree

1 file changed

+154
-42
lines changed

1 file changed

+154
-42
lines changed

datafusion-postgres/src/hooks/set_show.rs

Lines changed: 154 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ use std::sync::Arc;
33
use async_trait::async_trait;
44
use datafusion::arrow::datatypes::{DataType, Field, Schema};
55
use datafusion::common::{ParamValues, ToDFSchema};
6+
use datafusion::error::DataFusionError;
67
use datafusion::logical_expr::LogicalPlan;
78
use datafusion::prelude::SessionContext;
8-
use datafusion::sql::sqlparser::ast::{Set, Statement};
9+
use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
910
use log::{info, warn};
1011
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1112
use pgwire::api::ClientInfo;
@@ -134,39 +135,53 @@ where
134135
hivevar: false,
135136
variable,
136137
values,
137-
} if &variable.to_string() == "statement_timeout" => {
138-
let value = values[0].to_string();
139-
let timeout_str = value.trim_matches('"').trim_matches('\'');
140-
141-
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
142-
None
143-
} else {
144-
// Parse timeout value (supports ms, s, min formats)
145-
let timeout_ms = if timeout_str.ends_with("ms") {
146-
timeout_str.trim_end_matches("ms").parse::<u64>()
147-
} else if timeout_str.ends_with("s") {
148-
timeout_str
149-
.trim_end_matches("s")
150-
.parse::<u64>()
151-
.map(|s| s * 1000)
152-
} else if timeout_str.ends_with("min") {
153-
timeout_str
154-
.trim_end_matches("min")
155-
.parse::<u64>()
156-
.map(|m| m * 60 * 1000)
138+
} => {
139+
let var = variable.to_string().to_lowercase();
140+
if var == "statement_timeout" {
141+
let value = values[0].to_string();
142+
let timeout_str = value.trim_matches('"').trim_matches('\'');
143+
144+
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
145+
None
157146
} else {
158-
// Default to milliseconds
159-
timeout_str.parse::<u64>()
147+
// Parse timeout value (supports ms, s, min formats)
148+
let timeout_ms = if timeout_str.ends_with("ms") {
149+
timeout_str.trim_end_matches("ms").parse::<u64>()
150+
} else if timeout_str.ends_with("s") {
151+
timeout_str
152+
.trim_end_matches("s")
153+
.parse::<u64>()
154+
.map(|s| s * 1000)
155+
} else if timeout_str.ends_with("min") {
156+
timeout_str
157+
.trim_end_matches("min")
158+
.parse::<u64>()
159+
.map(|m| m * 60 * 1000)
160+
} else {
161+
// Default to milliseconds
162+
timeout_str.parse::<u64>()
163+
};
164+
165+
match timeout_ms {
166+
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
167+
_ => None,
168+
}
160169
};
161170

162-
match timeout_ms {
163-
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
164-
_ => None,
171+
client::set_statement_timeout(client, timeout);
172+
return Some(Ok(Response::Execution(Tag::new("SET"))));
173+
} else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle") {
174+
if values.len() > 0 {
175+
// postgres configuration variables
176+
let value = values[0].clone();
177+
if let Expr::Value(value) = value {
178+
client
179+
.metadata_mut()
180+
.insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
181+
return Some(Ok(Response::Execution(Tag::new("SET"))));
182+
}
165183
}
166-
};
167-
168-
client::set_statement_timeout(client, timeout);
169-
Some(Ok(Response::Execution(Tag::new("SET"))))
184+
}
170185
}
171186
Set::SetTimeZone {
172187
local: false,
@@ -175,19 +190,39 @@ where
175190
let tz = value.to_string();
176191
let tz = tz.trim_matches('"').trim_matches('\'');
177192
client::set_timezone(client, Some(tz));
178-
Some(Ok(Response::Execution(Tag::new("SET"))))
193+
return Some(Ok(Response::Execution(Tag::new("SET"))));
179194
}
180-
_ => {
181-
// pass SET query to datafusion
182-
let query = statement.to_string();
183-
if let Err(e) = session_context.sql(&query).await {
184-
warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
185-
}
195+
_ => {}
196+
}
186197

187-
// Always return SET success
188-
Some(Ok(Response::Execution(Tag::new("SET"))))
189-
}
198+
// fallback to datafusion and ignore all errors
199+
if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
200+
warn!(
201+
"SET statement {} is not supported by datafusion, error {e}, statement ignored",
202+
statement.to_string()
203+
);
190204
}
205+
206+
// Always return SET success
207+
Some(Ok(Response::Execution(Tag::new("SET"))))
208+
}
209+
210+
async fn execute_set_statement(
211+
session_context: &SessionContext,
212+
statement: Statement,
213+
) -> Result<(), DataFusionError> {
214+
let state = session_context.state();
215+
let logical_plan = state
216+
.statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
217+
statement,
218+
)))
219+
.await
220+
.and_then(|logical_plan| state.optimize(&logical_plan))?;
221+
222+
session_context
223+
.execute_logical_plan(logical_plan)
224+
.await
225+
.map(|_| ())
191226
}
192227

193228
async fn try_respond_show_statements<C>(
@@ -204,10 +239,11 @@ where
204239

205240
let variables = variable
206241
.iter()
207-
.map(|v| &v.value as &str)
242+
.map(|v| v.value.to_lowercase())
208243
.collect::<Vec<_>>();
244+
let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();
209245

210-
match &variables as &[&str] {
246+
match variables_ref.as_slice() {
211247
["time", "zone"] => {
212248
let timezone = client::get_timezone(client).unwrap_or("UTC");
213249
Some(mock_show_response("TimeZone", timezone).map(Response::Query))
@@ -238,6 +274,14 @@ where
238274
["transaction", "isolation", "level"] => {
239275
Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
240276
}
277+
["bytea_output"] | ["datestyle"] | ["intervalstyle"] => {
278+
let val = client
279+
.metadata()
280+
.get(&variables[0])
281+
.map(|v| v.as_str())
282+
.unwrap_or("");
283+
Some(mock_show_response(&variables[0], val).map(Response::Query))
284+
}
241285
_ => {
242286
info!("Unsupported show statement: {}", statement);
243287
Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
@@ -288,6 +332,74 @@ mod tests {
288332
assert!(show_response.unwrap().is_ok());
289333
}
290334

335+
#[tokio::test]
336+
async fn test_bytea_output_set_and_show() {
337+
let session_context = SessionContext::new();
338+
let mut client = MockClient::new();
339+
340+
// Test setting timeout to 5000ms
341+
let statement = Parser::new(&PostgreSqlDialect {})
342+
.try_with_sql("set bytea_output = 'hex'")
343+
.unwrap()
344+
.parse_statement()
345+
.unwrap();
346+
let set_response =
347+
try_respond_set_statements(&mut client, &statement, &session_context).await;
348+
349+
assert!(set_response.is_some());
350+
assert!(set_response.unwrap().is_ok());
351+
352+
// Verify the timeout was set in client metadata
353+
let bytea_output = client.metadata().get("bytea_output").unwrap();
354+
assert_eq!(bytea_output, "hex");
355+
356+
// Test SHOW statement_timeout
357+
let statement = Parser::new(&PostgreSqlDialect {})
358+
.try_with_sql("show bytea_output")
359+
.unwrap()
360+
.parse_statement()
361+
.unwrap();
362+
let show_response =
363+
try_respond_show_statements(&client, &statement, &session_context).await;
364+
365+
assert!(show_response.is_some());
366+
assert!(show_response.unwrap().is_ok());
367+
}
368+
369+
#[tokio::test]
370+
async fn test_date_style_set_and_show() {
371+
let session_context = SessionContext::new();
372+
let mut client = MockClient::new();
373+
374+
// Test setting timeout to 5000ms
375+
let statement = Parser::new(&PostgreSqlDialect {})
376+
.try_with_sql("set dateStyle = 'ISO, DMY'")
377+
.unwrap()
378+
.parse_statement()
379+
.unwrap();
380+
let set_response =
381+
try_respond_set_statements(&mut client, &statement, &session_context).await;
382+
383+
assert!(set_response.is_some());
384+
assert!(set_response.unwrap().is_ok());
385+
386+
// Verify the timeout was set in client metadata
387+
let bytea_output = client.metadata().get("datestyle").unwrap();
388+
assert_eq!(bytea_output, "ISO, DMY");
389+
390+
// Test SHOW statement_timeout
391+
let statement = Parser::new(&PostgreSqlDialect {})
392+
.try_with_sql("show dateStyle")
393+
.unwrap()
394+
.parse_statement()
395+
.unwrap();
396+
let show_response =
397+
try_respond_show_statements(&client, &statement, &session_context).await;
398+
399+
assert!(show_response.is_some());
400+
assert!(show_response.unwrap().is_ok());
401+
}
402+
291403
#[tokio::test]
292404
async fn test_statement_timeout_disable() {
293405
let session_context = SessionContext::new();

0 commit comments

Comments
 (0)