Skip to content

Commit d8dbf08

Browse files
authored
Rework SHOW handling (#213)
Rework SHOW handling to use ast values instead of custom string parsing.
1 parent 21229ab commit d8dbf08

File tree

1 file changed

+89
-50
lines changed

1 file changed

+89
-50
lines changed

datafusion-postgres/src/hooks/set_show.rs

Lines changed: 89 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ impl QueryHook for SetShowHook {
3232
try_respond_set_statements(client, statement, session_context).await
3333
}
3434
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
35-
let query = statement.to_string();
36-
let query_lower = query.to_lowercase();
37-
38-
try_respond_show_statements(client, &query_lower, session_context).await
35+
try_respond_show_statements(client, statement, session_context).await
3936
}
4037
_ => None,
4138
}
@@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
9390
try_respond_set_statements(client, statement, session_context).await
9491
}
9592
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
96-
let query = statement.to_string();
97-
let query_lower = query.to_lowercase();
98-
99-
try_respond_show_statements(client, &query_lower, session_context).await
93+
try_respond_show_statements(client, statement, session_context).await
10094
}
10195
_ => None,
10296
}
@@ -198,52 +192,56 @@ where
198192

199193
async fn try_respond_show_statements<C>(
200194
client: &C,
201-
query_lower: &str,
195+
statement: &Statement,
202196
session_context: &SessionContext,
203197
) -> Option<PgWireResult<Response>>
204198
where
205199
C: ClientInfo + ?Sized,
206200
{
207-
if query_lower.starts_with("show ") {
208-
let result = match query_lower.strip_suffix(";").unwrap_or(query_lower) {
209-
"show time zone" => {
210-
let timezone = client::get_timezone(client).unwrap_or("UTC");
211-
mock_show_response("TimeZone", timezone).map(Response::Query)
212-
}
213-
"show server_version" => {
214-
mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query)
215-
}
216-
"show transaction_isolation" => {
217-
mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query)
218-
}
219-
"show catalogs" => {
220-
let catalogs = session_context.catalog_names();
221-
let value = catalogs.join(", ");
222-
mock_show_response("Catalogs", &value).map(Response::Query)
223-
}
224-
"show search_path" => {
225-
let default_schema = "public";
226-
mock_show_response("search_path", default_schema).map(Response::Query)
227-
}
228-
"show statement_timeout" => {
229-
let timeout = client::get_statement_timeout(client);
230-
let timeout_str = match timeout {
231-
Some(duration) => format!("{}ms", duration.as_millis()),
232-
None => "0".to_string(),
233-
};
234-
mock_show_response("statement_timeout", &timeout_str).map(Response::Query)
235-
}
236-
"show transaction isolation level" => {
237-
mock_show_response("transaction_isolation", "read_committed").map(Response::Query)
238-
}
239-
_ => {
240-
info!("Unsupported show statement: {query_lower}");
241-
mock_show_response("unsupported_show_statement", "").map(Response::Query)
242-
}
243-
};
244-
Some(result)
245-
} else {
246-
None
201+
let Statement::ShowVariable { variable } = statement else {
202+
return None;
203+
};
204+
205+
let variables = variable
206+
.iter()
207+
.map(|v| &v.value as &str)
208+
.collect::<Vec<_>>();
209+
210+
match &variables as &[&str] {
211+
["time", "zone"] => {
212+
let timezone = client::get_timezone(client).unwrap_or("UTC");
213+
Some(mock_show_response("TimeZone", timezone).map(Response::Query))
214+
}
215+
["server_version"] => {
216+
Some(mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query))
217+
}
218+
["transaction_isolation"] => Some(
219+
mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
220+
),
221+
["catalogs"] => {
222+
let catalogs = session_context.catalog_names();
223+
let value = catalogs.join(", ");
224+
Some(mock_show_response("Catalogs", &value).map(Response::Query))
225+
}
226+
["search_path"] => {
227+
let default_schema = "public";
228+
Some(mock_show_response("search_path", default_schema).map(Response::Query))
229+
}
230+
["statement_timeout"] => {
231+
let timeout = client::get_statement_timeout(client);
232+
let timeout_str = match timeout {
233+
Some(duration) => format!("{}ms", duration.as_millis()),
234+
None => "0".to_string(),
235+
};
236+
Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
237+
}
238+
["transaction", "isolation", "level"] => {
239+
Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
240+
}
241+
_ => {
242+
info!("Unsupported show statement: {}", statement);
243+
Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
244+
}
247245
}
248246
}
249247

@@ -278,8 +276,13 @@ mod tests {
278276
assert_eq!(timeout, Some(Duration::from_millis(5000)));
279277

280278
// Test SHOW statement_timeout
279+
let statement = Parser::new(&PostgreSqlDialect {})
280+
.try_with_sql("show statement_timeout")
281+
.unwrap()
282+
.parse_statement()
283+
.unwrap();
281284
let show_response =
282-
try_respond_show_statements(&client, "show statement_timeout", &session_context).await;
285+
try_respond_show_statements(&client, &statement, &session_context).await;
283286

284287
assert!(show_response.is_some());
285288
assert!(show_response.unwrap().is_ok());
@@ -313,4 +316,40 @@ mod tests {
313316
let timeout = client::get_statement_timeout(&client);
314317
assert_eq!(timeout, None);
315318
}
319+
320+
#[tokio::test]
321+
async fn test_supported_show_statements_returned_columns() {
322+
let session_context = SessionContext::new();
323+
let client = MockClient::new();
324+
325+
let tests = [
326+
("show time zone", "TimeZone"),
327+
("show server_version", "server_version"),
328+
("show transaction_isolation", "transaction_isolation"),
329+
("show catalogs", "Catalogs"),
330+
("show search_path", "search_path"),
331+
("show statement_timeout", "statement_timeout"),
332+
("show transaction isolation level", "transaction_isolation"),
333+
];
334+
335+
for (query, expected_response_col) in tests {
336+
let statement = Parser::new(&PostgreSqlDialect {})
337+
.try_with_sql(&query)
338+
.unwrap()
339+
.parse_statement()
340+
.unwrap();
341+
let show_response =
342+
try_respond_show_statements(&client, &statement, &session_context).await;
343+
344+
let Some(Ok(Response::Query(show_response))) = show_response else {
345+
panic!("unexpected show response");
346+
};
347+
348+
assert_eq!(show_response.command_tag(), "SELECT");
349+
350+
let row_schema = show_response.row_schema();
351+
assert_eq!(row_schema.len(), 1);
352+
assert_eq!(row_schema[0].name(), expected_response_col);
353+
}
354+
}
316355
}

0 commit comments

Comments
 (0)