Skip to content

Commit 21229ab

Browse files
authored
Rework SET statement handling (#211)
* Rework SET statement handling Use proper syntax for set statement timeout statements. Replace much of the string matching for SET statements with with matching on proper ast types. * Fix clippy issues
1 parent 1823b8c commit 21229ab

File tree

1 file changed

+80
-89
lines changed

1 file changed

+80
-89
lines changed

datafusion-postgres/src/hooks/set_show.rs

Lines changed: 80 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema};
55
use datafusion::common::{ParamValues, ToDFSchema};
66
use datafusion::logical_expr::LogicalPlan;
77
use datafusion::prelude::SessionContext;
8-
use datafusion::sql::sqlparser::ast::Statement;
8+
use datafusion::sql::sqlparser::ast::{Set, Statement};
99
use log::{info, warn};
1010
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
1111
use pgwire::api::ClientInfo;
@@ -29,10 +29,7 @@ impl QueryHook for SetShowHook {
2929
) -> Option<PgWireResult<Response>> {
3030
match statement {
3131
Statement::Set { .. } => {
32-
let query = statement.to_string();
33-
let query_lower = query.to_lowercase();
34-
35-
try_respond_set_statements(client, &query_lower, session_context).await
32+
try_respond_set_statements(client, statement, session_context).await
3633
}
3734
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
3835
let query = statement.to_string();
@@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
9390
) -> Option<PgWireResult<Response>> {
9491
match statement {
9592
Statement::Set { .. } => {
96-
let query = statement.to_string();
97-
let query_lower = query.to_lowercase();
98-
99-
try_respond_set_statements(client, &query_lower, session_context).await
93+
try_respond_set_statements(client, statement, session_context).await
10094
}
10195
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
10296
let query = statement.to_string();
@@ -130,84 +124,75 @@ fn mock_show_response(name: &str, value: &str) -> PgWireResult<QueryResponse> {
130124

131125
async fn try_respond_set_statements<C>(
132126
client: &mut C,
133-
query_lower: &str,
127+
statement: &Statement,
134128
session_context: &SessionContext,
135129
) -> Option<PgWireResult<Response>>
136130
where
137131
C: ClientInfo + Send + Sync + ?Sized,
138132
{
139-
if query_lower.starts_with("set") {
140-
let result = if query_lower.starts_with("set time zone") {
141-
let parts: Vec<&str> = query_lower.split_whitespace().collect();
142-
if parts.len() >= 4 {
143-
let tz = parts[3].trim_matches('"');
144-
client::set_timezone(client, Some(tz));
145-
Ok(Response::Execution(Tag::new("SET")))
146-
} else {
147-
Err(PgWireError::UserError(Box::new(
148-
pgwire::error::ErrorInfo::new(
149-
"ERROR".to_string(),
150-
"42601".to_string(),
151-
"Invalid SET TIME ZONE syntax".to_string(),
152-
),
153-
)))
154-
}
155-
} else if query_lower.starts_with("set statement_timeout") {
156-
let parts: Vec<&str> = query_lower.split_whitespace().collect();
157-
if parts.len() >= 3 {
158-
let timeout_str = parts[2].trim_matches('"').trim_matches('\'');
133+
let Statement::Set(set_statement) = statement else {
134+
return None;
135+
};
159136

160-
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
161-
None
137+
match &set_statement {
138+
Set::SingleAssignment {
139+
scope: None,
140+
hivevar: false,
141+
variable,
142+
values,
143+
} if &variable.to_string() == "statement_timeout" => {
144+
let value = values[0].to_string();
145+
let timeout_str = value.trim_matches('"').trim_matches('\'');
146+
147+
let timeout = if timeout_str == "0" || timeout_str.is_empty() {
148+
None
149+
} else {
150+
// Parse timeout value (supports ms, s, min formats)
151+
let timeout_ms = if timeout_str.ends_with("ms") {
152+
timeout_str.trim_end_matches("ms").parse::<u64>()
153+
} else if timeout_str.ends_with("s") {
154+
timeout_str
155+
.trim_end_matches("s")
156+
.parse::<u64>()
157+
.map(|s| s * 1000)
158+
} else if timeout_str.ends_with("min") {
159+
timeout_str
160+
.trim_end_matches("min")
161+
.parse::<u64>()
162+
.map(|m| m * 60 * 1000)
162163
} else {
163-
// Parse timeout value (supports ms, s, min formats)
164-
let timeout_ms = if timeout_str.ends_with("ms") {
165-
timeout_str.trim_end_matches("ms").parse::<u64>()
166-
} else if timeout_str.ends_with("s") {
167-
timeout_str
168-
.trim_end_matches("s")
169-
.parse::<u64>()
170-
.map(|s| s * 1000)
171-
} else if timeout_str.ends_with("min") {
172-
timeout_str
173-
.trim_end_matches("min")
174-
.parse::<u64>()
175-
.map(|m| m * 60 * 1000)
176-
} else {
177-
// Default to milliseconds
178-
timeout_str.parse::<u64>()
179-
};
180-
181-
match timeout_ms {
182-
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
183-
_ => None,
184-
}
164+
// Default to milliseconds
165+
timeout_str.parse::<u64>()
185166
};
186167

187-
client::set_statement_timeout(client, timeout);
188-
Ok(Response::Execution(Tag::new("SET")))
189-
} else {
190-
Err(PgWireError::UserError(Box::new(
191-
pgwire::error::ErrorInfo::new(
192-
"ERROR".to_string(),
193-
"42601".to_string(),
194-
"Invalid SET statement_timeout syntax".to_string(),
195-
),
196-
)))
197-
}
198-
} else {
168+
match timeout_ms {
169+
Ok(ms) if ms > 0 => Some(std::time::Duration::from_millis(ms)),
170+
_ => None,
171+
}
172+
};
173+
174+
client::set_statement_timeout(client, timeout);
175+
Some(Ok(Response::Execution(Tag::new("SET"))))
176+
}
177+
Set::SetTimeZone {
178+
local: false,
179+
value,
180+
} => {
181+
let tz = value.to_string();
182+
let tz = tz.trim_matches('"').trim_matches('\'');
183+
client::set_timezone(client, Some(tz));
184+
Some(Ok(Response::Execution(Tag::new("SET"))))
185+
}
186+
_ => {
199187
// pass SET query to datafusion
200-
if let Err(e) = session_context.sql(query_lower).await {
201-
warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
188+
let query = statement.to_string();
189+
if let Err(e) = session_context.sql(&query).await {
190+
warn!("SET statement {query} is not supported by datafusion, error {e}, statement ignored");
202191
}
203192

204193
// Always return SET success
205-
Ok(Response::Execution(Tag::new("SET")))
206-
};
207-
208-
Some(result)
209-
} else {
210-
None
194+
Some(Ok(Response::Execution(Tag::new("SET"))))
195+
}
211196
}
212197
}
213198

@@ -266,6 +251,8 @@ where
266251
mod tests {
267252
use std::time::Duration;
268253

254+
use datafusion::sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser};
255+
269256
use super::*;
270257
use crate::testing::MockClient;
271258

@@ -275,12 +262,13 @@ mod tests {
275262
let mut client = MockClient::new();
276263

277264
// Test setting timeout to 5000ms
278-
let set_response = try_respond_set_statements(
279-
&mut client,
280-
"set statement_timeout '5000ms'",
281-
&session_context,
282-
)
283-
.await;
265+
let statement = Parser::new(&PostgreSqlDialect {})
266+
.try_with_sql("set statement_timeout to '5000ms'")
267+
.unwrap()
268+
.parse_statement()
269+
.unwrap();
270+
let set_response =
271+
try_respond_set_statements(&mut client, &statement, &session_context).await;
284272

285273
assert!(set_response.is_some());
286274
assert!(set_response.unwrap().is_ok());
@@ -303,19 +291,22 @@ mod tests {
303291
let mut client = MockClient::new();
304292

305293
// Set timeout first
306-
let resp = try_respond_set_statements(
307-
&mut client,
308-
"set statement_timeout '1000ms'",
309-
&session_context,
310-
)
311-
.await;
294+
let statement = Parser::new(&PostgreSqlDialect {})
295+
.try_with_sql("set statement_timeout to '1000ms'")
296+
.unwrap()
297+
.parse_statement()
298+
.unwrap();
299+
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
312300
assert!(resp.is_some());
313301
assert!(resp.unwrap().is_ok());
314302

315303
// Disable timeout with 0
316-
let resp =
317-
try_respond_set_statements(&mut client, "set statement_timeout '0'", &session_context)
318-
.await;
304+
let statement = Parser::new(&PostgreSqlDialect {})
305+
.try_with_sql("set statement_timeout to '0'")
306+
.unwrap()
307+
.parse_statement()
308+
.unwrap();
309+
let resp = try_respond_set_statements(&mut client, &statement, &session_context).await;
319310
assert!(resp.is_some());
320311
assert!(resp.unwrap().is_ok());
321312

0 commit comments

Comments
 (0)