Skip to content

Commit bf68272

Browse files
committed
Rework SET statement handling
Use proper syntax for set statement timeout statements. Replace much of the string matching for SET statements with with matching on proper ast types.
1 parent efff49d commit bf68272

File tree

1 file changed

+80
-89
lines changed

1 file changed

+80
-89
lines changed

datafusion-postgres/src/hooks/set_show.rs

Lines changed: 80 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema};
55
use datafusion::common::{ParamValues, ToDFSchema};
66
use datafusion::logical_expr::LogicalPlan;
77
use datafusion::prelude::SessionContext;
8-
use datafusion::sql::sqlparser::ast::Statement;
8+
use datafusion::sql::sqlparser::ast::{Set, Statement};
99
use log::{info, warn};
1010
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1111
use pgwire::api::ClientInfo;
@@ -29,10 +29,7 @@ impl QueryHook for SetShowHook {
2929
) -> Option<PgWireResult<Response>> {
3030
match statement {
3131
Statement::Set { .. } => {
32-
let query = statement.to_string();
33-
let query_lower = query.to_lowercase();
34-
35-
try_respond_set_statements(client, &query_lower, session_context).await
32+
try_respond_set_statements(client, &statement, session_context).await
3633
}
3734
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
3835
let query = statement.to_string();
@@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
9390
) -> Option<PgWireResult<Response>> {
9491
match statement {
9592
Statement::Set { .. } => {
96-
let query = statement.to_string();
97-
let query_lower = query.to_lowercase();
98-
99-
try_respond_set_statements(client, &query_lower, session_context).await
93+
try_respond_set_statements(client, &statement, session_context).await
10094
}
10195
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
10296
let query = statement.to_string();
@@ -130,84 +124,75 @@ fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
130124

131125
async fn try_respond_set_statements<C>(
132126
client: &mut C,
133-
query_lower: &str,
127+
statement: &Statement,
134128
session_context: &SessionContext,
135129
) -> Option<PgWireResult<Response>>
136130
where
137131
C: ClientInfo + Send + Sync + ?Sized,
138132
{
139-
if query_lower.starts_with("set") {
140-
let result = if query_lower.starts_with("set time zone") {
141-
let parts: Vec<&str> = query_lower.split_whitespace().collect();
142-
if parts.len() >= 4 {
143-
let tz = parts[3].trim_matches('"');
144-
client::set_timezone(client, Some(tz));
145-
Ok(Response::Execution(Tag::new("SET")))
146-
} else {
147-
Err(PgWireError::UserError(Box::new(
148-
pgwire::error::ErrorInfo::new(
149-
"ERROR".to_string(),
150-
"42601".to_string(),
151-
"Invalid SET TIME ZONE syntax".to_string(),
152-
),
153-
)))
154-
}
155-
} else if query_lower.starts_with("set statement_timeout") {
156-
let parts: Vec<&str> = query_lower.split_whitespace().collect();
157-
if parts.len() >= 3 {
158-
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
133+
let Statement::Set(set_statement) = statement else {
134+
return None;
135+
};
159136

160-
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
161-
None
137+
match &set_statement {
138+
Set::SingleAssignment {
139+
scope: None,
140+
hivevar: false,
141+
variable,
142+
values,
143+
} if &variable.to_string() == "statement_timeout" => {
144+
let value = values[0].to_string();
145+
let timeout_str = value.trim_matches('"').trim_matches('\'');
146+
147+
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
148+
None
149+
} else {
150+
// Parse timeout value (supports ms, s, min formats)
151+
let timeout_ms = if timeout_str.ends_with("ms") {
152+
timeout_str.trim_end_matches("ms").parse::<u64>()
153+
} else if timeout_str.ends_with("s") {
154+
timeout_str
155+
.trim_end_matches("s")
156+
.parse::<u64>()
157+
.map(|s| s * 1000)
158+
} else if timeout_str.ends_with("min") {
159+
timeout_str
160+
.trim_end_matches("min")
161+
.parse::<u64>()
162+
.map(|m| m * 60 * 1000)
162163
} else {
163-
// Parse timeout value (supports ms, s, min formats)
164-
let timeout_ms = if timeout_str.ends_with("ms") {
165-
timeout_str.trim_end_matches("ms").parse::<u64>()
166-
} else if timeout_str.ends_with("s") {
167-
timeout_str
168-
.trim_end_matches("s")
169-
.parse::<u64>()
170-
.map(|s| s * 1000)
171-
} else if timeout_str.ends_with("min") {
172-
timeout_str
173-
.trim_end_matches("min")
174-
.parse::<u64>()
175-
.map(|m| m * 60 * 1000)
176-
} else {
177-
// Default to milliseconds
178-
timeout_str.parse::<u64>()
179-
};
180-
181-
match timeout_ms {
182-
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
183-
_ => None,
184-
}
164+
// Default to milliseconds
165+
timeout_str.parse::<u64>()
185166
};
186167

187-
client::set_statement_timeout(client, timeout);
188-
Ok(Response::Execution(Tag::new("SET")))
189-
} else {
190-
Err(PgWireError::UserError(Box::new(
191-
pgwire::error::ErrorInfo::new(
192-
"ERROR".to_string(),
193-
"42601".to_string(),
194-
"Invalid SET statement_timeout syntax".to_string(),
195-
),
196-
)))
197-
}
198-
} else {
168+
match timeout_ms {
169+
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
170+
_ => None,
171+
}
172+
};
173+
174+
client::set_statement_timeout(client, timeout);
175+
Some(Ok(Response::Execution(Tag::new("SET"))))
176+
}
177+
Set::SetTimeZone {
178+
local: false,
179+
value,
180+
} => {
181+
let tz = value.to_string();
182+
let tz = tz.trim_matches('"').trim_matches('\'');
183+
client::set_timezone(client, Some(&tz));
184+
Some(Ok(Response::Execution(Tag::new("SET"))))
185+
}
186+
_ => {
199187
// pass SET query to datafusion
200-
if let Err(e) = session_context.sql(query_lower).await {
201-
warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
188+
let query = statement.to_string();
189+
if let Err(e) = session_context.sql(&query).await {
190+
warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
202191
}
203192

204193
// Always return SET success
205-
Ok(Response::Execution(Tag::new("SET")))
206-
};
207-
208-
Some(result)
209-
} else {
210-
None
194+
Some(Ok(Response::Execution(Tag::new("SET"))))
195+
}
211196
}
212197
}
213198

@@ -266,6 +251,8 @@ where
266251
mod tests {
267252
use std::time::Duration;
268253

254+
use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};
255+
269256
use super::*;
270257
use crate::testing::MockClient;
271258

@@ -275,12 +262,13 @@ mod tests {
275262
let mut client = MockClient::new();
276263

277264
// Test setting timeout to 5000ms
278-
let set_response = try_respond_set_statements(
279-
&mut client,
280-
"set statement_timeout '5000ms'",
281-
&session_context,
282-
)
283-
.await;
265+
let statement = Parser::new(&PostgreSqlDialect {})
266+
.try_with_sql("set statement_timeout to '5000ms'")
267+
.unwrap()
268+
.parse_statement()
269+
.unwrap();
270+
let set_response =
271+
try_respond_set_statements(&mut client, &statement, &session_context).await;
284272

285273
assert!(set_response.is_some());
286274
assert!(set_response.unwrap().is_ok());
@@ -303,19 +291,22 @@ mod tests {
303291
let mut client = MockClient::new();
304292

305293
// Set timeout first
306-
let resp = try_respond_set_statements(
307-
&mut client,
308-
"set statement_timeout '1000ms'",
309-
&session_context,
310-
)
311-
.await;
294+
let statement = Parser::new(&PostgreSqlDialect {})
295+
.try_with_sql("set statement_timeout to '1000ms'")
296+
.unwrap()
297+
.parse_statement()
298+
.unwrap();
299+
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
312300
assert!(resp.is_some());
313301
assert!(resp.unwrap().is_ok());
314302

315303
// Disable timeout with 0
316-
let resp =
317-
try_respond_set_statements(&mut client, "set statement_timeout '0'", &session_context)
318-
.await;
304+
let statement = Parser::new(&PostgreSqlDialect {})
305+
.try_with_sql("set statement_timeout to '0'")
306+
.unwrap()
307+
.parse_statement()
308+
.unwrap();
309+
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
319310
assert!(resp.is_some());
320311
assert!(resp.unwrap().is_ok());
321312

0 commit comments

Comments
 (0)