Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 80 additions & 89 deletions datafusion-postgres/src/hooks/set_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{ParamValues, ToDFSchema};
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::ast::{Set, Statement};
use log::{info, warn};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::ClientInfo;
Expand All @@ -29,10 +29,7 @@ impl QueryHook for SetShowHook {
) -> Option<PgWireResult<Response>> {
match statement {
Statement::Set { .. } => {
let query = statement.to_string();
let query_lower = query.to_lowercase();

try_respond_set_statements(client, &query_lower, session_context).await
try_respond_set_statements(client, &statement, session_context).await
}
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
let query = statement.to_string();
Expand Down Expand Up @@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
) -> Option<PgWireResult<Response>> {
match statement {
Statement::Set { .. } => {
let query = statement.to_string();
let query_lower = query.to_lowercase();

try_respond_set_statements(client, &query_lower, session_context).await
try_respond_set_statements(client, &statement, session_context).await
}
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
let query = statement.to_string();
Expand Down Expand Up @@ -130,84 +124,75 @@ fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {

async fn try_respond_set_statements<C>(
client: &mut C,
query_lower: &str,
statement: &Statement,
session_context: &SessionContext,
) -> Option<PgWireResult<Response>>
where
C: ClientInfo + Send + Sync + ?Sized,
{
if query_lower.starts_with("set") {
let result = if query_lower.starts_with("set time zone") {
let parts: Vec<&str> = query_lower.split_whitespace().collect();
if parts.len() >= 4 {
let tz = parts[3].trim_matches('"');
client::set_timezone(client, Some(tz));
Ok(Response::Execution(Tag::new("SET")))
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
"Invalid SET TIME ZONE syntax".to_string(),
),
)))
}
} else if query_lower.starts_with("set statement_timeout") {
let parts: Vec<&str> = query_lower.split_whitespace().collect();
if parts.len() >= 3 {
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
let Statement::Set(set_statement) = statement else {
return None;
};

let timeout = if timeout_str == "0" || timeout_str.is_empty() {
None
match &set_statement {
Set::SingleAssignment {
scope: None,
hivevar: false,
variable,
values,
} if &variable.to_string() == "statement_timeout" => {
let value = values[0].to_string();
let timeout_str = value.trim_matches('"').trim_matches('\'');

let timeout = if timeout_str == "0" || timeout_str.is_empty() {
None
} else {
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} else {
// Parse timeout value (supports ms, s, min formats)
let timeout_ms = if timeout_str.ends_with("ms") {
timeout_str.trim_end_matches("ms").parse::<u64>()
} else if timeout_str.ends_with("s") {
timeout_str
.trim_end_matches("s")
.parse::<u64>()
.map(|s| s * 1000)
} else if timeout_str.ends_with("min") {
timeout_str
.trim_end_matches("min")
.parse::<u64>()
.map(|m| m * 60 * 1000)
} else {
// Default to milliseconds
timeout_str.parse::<u64>()
};

match timeout_ms {
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
_ => None,
}
// Default to milliseconds
timeout_str.parse::<u64>()
};

client::set_statement_timeout(client, timeout);
Ok(Response::Execution(Tag::new("SET")))
} else {
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
"Invalid SET statement_timeout syntax".to_string(),
),
)))
}
} else {
match timeout_ms {
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
_ => None,
}
};

client::set_statement_timeout(client, timeout);
Some(Ok(Response::Execution(Tag::new("SET"))))
}
Set::SetTimeZone {
local: false,
value,
} => {
let tz = value.to_string();
let tz = tz.trim_matches('"').trim_matches('\'');
client::set_timezone(client, Some(&tz));
Some(Ok(Response::Execution(Tag::new("SET"))))
}
_ => {
// pass SET query to datafusion
if let Err(e) = session_context.sql(query_lower).await {
warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
let query = statement.to_string();
if let Err(e) = session_context.sql(&query).await {
warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
}

// Always return SET success
Ok(Response::Execution(Tag::new("SET")))
};

Some(result)
} else {
None
Some(Ok(Response::Execution(Tag::new("SET"))))
}
}
}

Expand Down Expand Up @@ -266,6 +251,8 @@ where
mod tests {
use std::time::Duration;

use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};

use super::*;
use crate::testing::MockClient;

Expand All @@ -275,12 +262,13 @@ mod tests {
let mut client = MockClient::new();

// Test setting timeout to 5000ms
let set_response = try_respond_set_statements(
&mut client,
"set statement_timeout '5000ms'",
&session_context,
)
.await;
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("set statement_timeout to '5000ms'")
.unwrap()
.parse_statement()
.unwrap();
let set_response =
try_respond_set_statements(&mut client, &statement, &session_context).await;

assert!(set_response.is_some());
assert!(set_response.unwrap().is_ok());
Expand All @@ -303,19 +291,22 @@ mod tests {
let mut client = MockClient::new();

// Set timeout first
let resp = try_respond_set_statements(
&mut client,
"set statement_timeout '1000ms'",
&session_context,
)
.await;
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("set statement_timeout to '1000ms'")
.unwrap()
.parse_statement()
.unwrap();
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
assert!(resp.is_some());
assert!(resp.unwrap().is_ok());

// Disable timeout with 0
let resp =
try_respond_set_statements(&mut client, "set statement_timeout '0'", &session_context)
.await;
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("set statement_timeout to '0'")
.unwrap()
.parse_statement()
.unwrap();
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
assert!(resp.is_some());
assert!(resp.unwrap().is_ok());

Expand Down
Loading