Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 89 additions & 50 deletions datafusion-postgres/src/hooks/set_show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ impl QueryHook for SetShowHook {
try_respond_set_statements(client, statement, session_context).await
}
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
let query = statement.to_string();
let query_lower = query.to_lowercase();

try_respond_show_statements(client, &query_lower, session_context).await
try_respond_show_statements(client, statement, session_context).await
}
_ => None,
}
Expand Down Expand Up @@ -93,10 +90,7 @@ impl QueryHook for SetShowHook {
try_respond_set_statements(client, statement, session_context).await
}
Statement::ShowVariable { .. } | Statement::ShowStatus { .. } => {
let query = statement.to_string();
let query_lower = query.to_lowercase();

try_respond_show_statements(client, &query_lower, session_context).await
try_respond_show_statements(client, statement, session_context).await
}
_ => None,
}
Expand Down Expand Up @@ -198,52 +192,56 @@ where

async fn try_respond_show_statements<C>(
client: &C,
query_lower: &str,
statement: &Statement,
session_context: &SessionContext,
) -> Option<PgWireResult<Response>>
where
C: ClientInfo + ?Sized,
{
if query_lower.starts_with("show ") {
let result = match query_lower.strip_suffix(";").unwrap_or(query_lower) {
"show time zone" => {
let timezone = client::get_timezone(client).unwrap_or("UTC");
mock_show_response("TimeZone", timezone).map(Response::Query)
}
"show server_version" => {
mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query)
}
"show transaction_isolation" => {
mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query)
}
"show catalogs" => {
let catalogs = session_context.catalog_names();
let value = catalogs.join(", ");
mock_show_response("Catalogs", &value).map(Response::Query)
}
"show search_path" => {
let default_schema = "public";
mock_show_response("search_path", default_schema).map(Response::Query)
}
"show statement_timeout" => {
let timeout = client::get_statement_timeout(client);
let timeout_str = match timeout {
Some(duration) => format!("{}ms", duration.as_millis()),
None => "0".to_string(),
};
mock_show_response("statement_timeout", &timeout_str).map(Response::Query)
}
"show transaction isolation level" => {
mock_show_response("transaction_isolation", "read_committed").map(Response::Query)
}
_ => {
info!("Unsupported show statement: {query_lower}");
mock_show_response("unsupported_show_statement", "").map(Response::Query)
}
};
Some(result)
} else {
None
let Statement::ShowVariable { variable } = statement else {
return None;
};

let variables = variable
.iter()
.map(|v| &v.value as &str)
.collect::<Vec<_>>();

match &variables as &[&str] {
["time", "zone"] => {
let timezone = client::get_timezone(client).unwrap_or("UTC");
Some(mock_show_response("TimeZone", timezone).map(Response::Query))
}
["server_version"] => {
Some(mock_show_response("server_version", "15.0 (DataFusion)").map(Response::Query))
}
["transaction_isolation"] => Some(
mock_show_response("transaction_isolation", "read uncommitted").map(Response::Query),
),
["catalogs"] => {
let catalogs = session_context.catalog_names();
let value = catalogs.join(", ");
Some(mock_show_response("Catalogs", &value).map(Response::Query))
}
["search_path"] => {
let default_schema = "public";
Some(mock_show_response("search_path", default_schema).map(Response::Query))
}
["statement_timeout"] => {
let timeout = client::get_statement_timeout(client);
let timeout_str = match timeout {
Some(duration) => format!("{}ms", duration.as_millis()),
None => "0".to_string(),
};
Some(mock_show_response("statement_timeout", &timeout_str).map(Response::Query))
}
["transaction", "isolation", "level"] => {
Some(mock_show_response("transaction_isolation", "read_committed").map(Response::Query))
}
_ => {
info!("Unsupported show statement: {}", statement);
Some(mock_show_response("unsupported_show_statement", "").map(Response::Query))
}
}
}

Expand Down Expand Up @@ -278,8 +276,13 @@ mod tests {
assert_eq!(timeout, Some(Duration::from_millis(5000)));

// Test SHOW statement_timeout
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql("show statement_timeout")
.unwrap()
.parse_statement()
.unwrap();
let show_response =
try_respond_show_statements(&client, "show statement_timeout", &session_context).await;
try_respond_show_statements(&client, &statement, &session_context).await;

assert!(show_response.is_some());
assert!(show_response.unwrap().is_ok());
Expand Down Expand Up @@ -313,4 +316,40 @@ mod tests {
let timeout = client::get_statement_timeout(&client);
assert_eq!(timeout, None);
}

#[tokio::test]
async fn test_supported_show_statements_returned_columns() {
let session_context = SessionContext::new();
let client = MockClient::new();

let tests = [
("show time zone", "TimeZone"),
("show server_version", "server_version"),
("show transaction_isolation", "transaction_isolation"),
("show catalogs", "Catalogs"),
("show search_path", "search_path"),
("show statement_timeout", "statement_timeout"),
("show transaction isolation level", "transaction_isolation"),
];

for (query, expected_response_col) in tests {
let statement = Parser::new(&PostgreSqlDialect {})
.try_with_sql(&query)
.unwrap()
.parse_statement()
.unwrap();
let show_response =
try_respond_show_statements(&client, &statement, &session_context).await;

let Some(Ok(Response::Query(show_response))) = show_response else {
panic!("unexpected show response");
};

assert_eq!(show_response.command_tag(), "SELECT");

let row_schema = show_response.row_schema();
assert_eq!(row_schema.len(), 1);
assert_eq!(row_schema[0].name(), expected_response_col);
}
}
}
Loading