Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 154 additions & 42 deletions datafusion-postgres/src/hooks/set_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{ParamValues, ToDFSchema};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::{Set, Statement};
use datafusion::sql::sqlparser::ast::{Expr, Set, Statement};
use log::{info, warn};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::ClientInfo;
Expand Down Expand Up @@ -134,39 +135,53 @@ where
hivevar: false,
variable,
values,
} if &variable.to_string() == "statement_timeout" => {
let value = values[0].to_string();
let timeout_str = value.trim_matches('"').trim_matches('\'');

let timeout = if timeout_str == "0" || timeout_str.is_empty() {
None
} else {
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} => {
let var = variable.to_string().to_lowercase();
if var == "statement_timeout" {
let value = values[0].to_string();
let timeout_str = value.trim_matches('"').trim_matches('\'');

let timeout = if timeout_str == "0" || timeout_str.is_empty() {
None
} else {
// Default to milliseconds
timeout_str.parse::<u64>()
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} else {
// Default to milliseconds
timeout_str.parse::<u64>()
};

match timeout_ms {
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
_ => None,
}
};

match timeout_ms {
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
_ => None,
client::set_statement_timeout(client, timeout);
return Some(Ok(Response::Execution(Tag::new("SET"))));
} else if matches!(var.as_str(), "datestyle" | "bytea_output" | "intervalstyle")
&& !values.is_empty()
{
// postgres configuration variables
let value = values[0].clone();
if let Expr::Value(value) = value {
client
.metadata_mut()
.insert(var, value.into_string().unwrap_or_else(|| "".to_string()));
return Some(Ok(Response::Execution(Tag::new("SET"))));
}
};

client::set_statement_timeout(client, timeout);
Some(Ok(Response::Execution(Tag::new("SET"))))
}
}
Set::SetTimeZone {
local: false,
Expand All @@ -175,19 +190,39 @@ where
let tz = value.to_string();
let tz = tz.trim_matches('"').trim_matches('\'');
client::set_timezone(client, Some(tz));
Some(Ok(Response::Execution(Tag::new("SET"))))
return Some(Ok(Response::Execution(Tag::new("SET"))));
}
_ => {
// pass SET query to datafusion
let query = statement.to_string();
if let Err(e) = session_context.sql(&query).await {
warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
}
_ => {}
}

// Always return SET success
Some(Ok(Response::Execution(Tag::new("SET"))))
}
// fallback to datafusion and ignore all errors
if let Err(e) = execute_set_statement(session_context, statement.clone()).await {
warn!(
"SET statement {} is not supported by datafusion, error {e}, statement ignored",
statement
);
}

// Always return SET success
Some(Ok(Response::Execution(Tag::new("SET"))))
}

async fn execute_set_statement(
session_context: &SessionContext,
statement: Statement,
) -> Result<(), DataFusionError> {
let state = session_context.state();
let logical_plan = state
.statement_to_plan(datafusion::sql::parser::Statement::Statement(Box::new(
statement,
)))
.await
.and_then(|logical_plan| state.optimize(&logical_plan))?;

session_context
.execute_logical_plan(logical_plan)
.await
.map(|_| ())
}

async fn try_respond_show_statements<C>(
Expand All @@ -204,10 +239,11 @@ where

let variables = variable
.iter()
.map(|v| &v.value as &str)
.map(|v| v.value.to_lowercase())
.collect::<Vec<_>>();
let variables_ref = variables.iter().map(|s| s.as_str()).collect::<Vec<_>>();

match &variables as &[&str] {
match variables_ref.as_slice() {
["time", "zone"] => {
let timezone = client::get_timezone(client).unwrap_or("UTC");
Some(mock_show_response("TimeZone", timezone).map(Response::Query))
Expand Down Expand Up @@ -238,6 +274,14 @@ where
["transaction", "isolation", "level"] => {
Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
}
["bytea_output"] | ["datestyle"] | ["intervalstyle"] => {
let val = client
.metadata()
.get(&variables[0])
.map(|v| v.as_str())
.unwrap_or("");
Some(mock_show_response(&variables[0], val).map(Response::Query))
}
_ => {
info!("Unsupported show statement: {}", statement);
Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
Expand Down Expand Up @@ -288,6 +332,74 @@ mod tests {
assert!(show_response.unwrap().is_ok());
}

#[tokio::test]
async fn test_bytea_output_set_and_show() {
let session_context = SessionContext::new();
let mut client = MockClient::new();

// Test setting timeout to 5000ms
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("set bytea_output = 'hex'")
.unwrap()
.parse_statement()
.unwrap();
let set_response =
try_respond_set_statements(&mut client, &statement, &session_context).await;

assert!(set_response.is_some());
assert!(set_response.unwrap().is_ok());

// Verify the timeout was set in client metadata
let bytea_output = client.metadata().get("bytea_output").unwrap();
assert_eq!(bytea_output, "hex");

// Test SHOW statement_timeout
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("show bytea_output")
.unwrap()
.parse_statement()
.unwrap();
let show_response =
try_respond_show_statements(&client, &statement, &session_context).await;

assert!(show_response.is_some());
assert!(show_response.unwrap().is_ok());
}

#[tokio::test]
async fn test_date_style_set_and_show() {
let session_context = SessionContext::new();
let mut client = MockClient::new();

// Test setting timeout to 5000ms
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("set dateStyle = 'ISO, DMY'")
.unwrap()
.parse_statement()
.unwrap();
let set_response =
try_respond_set_statements(&mut client, &statement, &session_context).await;

assert!(set_response.is_some());
assert!(set_response.unwrap().is_ok());

// Verify the timeout was set in client metadata
let bytea_output = client.metadata().get("datestyle").unwrap();
assert_eq!(bytea_output, "ISO, DMY");

// Test SHOW statement_timeout
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("show dateStyle")
.unwrap()
.parse_statement()
.unwrap();
let show_response =
try_respond_show_statements(&client, &statement, &session_context).await;

assert!(show_response.is_some());
assert!(show_response.unwrap().is_ok());
}

#[tokio::test]
async fn test_statement_timeout_disable() {
let session_context = SessionContext::new();
Expand Down
Loading