|
| 1 | +use std::collections::HashSet; |
1 | 2 | use std::sync::Arc; |
2 | 3 |
|
3 | 4 | use datafusion::sql::sqlparser::ast::Expr; |
4 | 5 | use datafusion::sql::sqlparser::ast::Ident; |
| 6 | +use datafusion::sql::sqlparser::ast::OrderByKind; |
| 7 | +use datafusion::sql::sqlparser::ast::Query; |
5 | 8 | use datafusion::sql::sqlparser::ast::Select; |
6 | 9 | use datafusion::sql::sqlparser::ast::SelectItem; |
7 | 10 | use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind; |
8 | 11 | use datafusion::sql::sqlparser::ast::SetExpr; |
9 | 12 | use datafusion::sql::sqlparser::ast::Statement; |
| 13 | +use datafusion::sql::sqlparser::ast::TableFactor; |
| 14 | +use datafusion::sql::sqlparser::ast::TableWithJoins; |
10 | 15 | use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; |
11 | 16 | use datafusion::sql::sqlparser::parser::Parser; |
12 | 17 | use datafusion::sql::sqlparser::parser::ParserError; |
@@ -135,6 +140,124 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite { |
135 | 140 | } |
136 | 141 | } |
137 | 142 |
|
| 143 | +/// Prepend qualifier for order by or filter when there is qualified wildcard |
| 144 | +/// |
| 145 | +/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not |
| 146 | +/// accepted by datafusion. |
| 147 | +#[derive(Debug)] |
| 148 | +pub struct ResolveUnqualifiedIdentifer; |
| 149 | + |
| 150 | +impl ResolveUnqualifiedIdentifer { |
| 151 | + fn rewrite_unqualified_identifiers(query: &mut Box<Query>) { |
| 152 | + if let SetExpr::Select(select) = query.body.as_mut() { |
| 153 | + // Step 1: Find all table aliases from FROM and JOIN clauses. |
| 154 | + let table_aliases = Self::get_table_aliases(&select.from); |
| 155 | + |
| 156 | + // Step 2: Check for a single qualified wildcard in the projection. |
| 157 | + let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection); |
| 158 | + if qualified_wildcard_alias.is_none() || table_aliases.is_empty() { |
| 159 | + return; // Conditions not met. |
| 160 | + } |
| 161 | + |
| 162 | + let wildcard_alias = qualified_wildcard_alias.unwrap(); |
| 163 | + |
| 164 | + // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses. |
| 165 | + if let Some(selection) = &mut select.selection { |
| 166 | + Self::rewrite_expr(selection, &wildcard_alias, &table_aliases); |
| 167 | + } |
| 168 | + |
| 169 | + if let Some(OrderByKind::Expressions(order_by_exprs)) = |
| 170 | + query.order_by.as_mut().map(|o| &mut o.kind) |
| 171 | + { |
| 172 | + for order_by_expr in order_by_exprs { |
| 173 | + Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases); |
| 174 | + } |
| 175 | + } |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> { |
| 180 | + let mut aliases = HashSet::new(); |
| 181 | + for table_with_joins in tables { |
| 182 | + if let TableFactor::Table { |
| 183 | + alias: Some(alias), .. |
| 184 | + } = &table_with_joins.relation |
| 185 | + { |
| 186 | + aliases.insert(alias.name.value.clone()); |
| 187 | + } |
| 188 | + for join in &table_with_joins.joins { |
| 189 | + if let TableFactor::Table { |
| 190 | + alias: Some(alias), .. |
| 191 | + } = &join.relation |
| 192 | + { |
| 193 | + aliases.insert(alias.name.value.clone()); |
| 194 | + } |
| 195 | + } |
| 196 | + } |
| 197 | + aliases |
| 198 | + } |
| 199 | + |
| 200 | + fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> { |
| 201 | + let mut qualified_wildcards = projection |
| 202 | + .iter() |
| 203 | + .filter_map(|item| { |
| 204 | + if let SelectItem::QualifiedWildcard( |
| 205 | + SelectItemQualifiedWildcardKind::ObjectName(objname), |
| 206 | + _, |
| 207 | + ) = item |
| 208 | + { |
| 209 | + Some( |
| 210 | + objname |
| 211 | + .0 |
| 212 | + .iter() |
| 213 | + .map(|v| v.as_ident().unwrap().value.clone()) |
| 214 | + .collect::<Vec<_>>() |
| 215 | + .join("."), |
| 216 | + ) |
| 217 | + } else { |
| 218 | + None |
| 219 | + } |
| 220 | + }) |
| 221 | + .collect::<Vec<_>>(); |
| 222 | + |
| 223 | + if qualified_wildcards.len() == 1 { |
| 224 | + Some(qualified_wildcards.remove(0)) |
| 225 | + } else { |
| 226 | + None |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) { |
| 231 | + match expr { |
| 232 | + Expr::Identifier(ident) => { |
| 233 | + // If the identifier is not a table alias itself, rewrite it. |
| 234 | + if !table_aliases.contains(&ident.value) { |
| 235 | + *expr = Expr::CompoundIdentifier(vec![ |
| 236 | + Ident::new(wildcard_alias.to_string()), |
| 237 | + ident.clone(), |
| 238 | + ]); |
| 239 | + } |
| 240 | + } |
| 241 | + Expr::BinaryOp { left, right, .. } => { |
| 242 | + Self::rewrite_expr(left, wildcard_alias, table_aliases); |
| 243 | + Self::rewrite_expr(right, wildcard_alias, table_aliases); |
| 244 | + } |
| 245 | + // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.) |
| 246 | + _ => {} |
| 247 | + } |
| 248 | + } |
| 249 | +} |
| 250 | + |
| 251 | +impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer { |
| 252 | + fn rewrite(&self, mut statement: Statement) -> Statement { |
| 253 | + if let Statement::Query(query) = &mut statement { |
| 254 | + Self::rewrite_unqualified_identifiers(query); |
| 255 | + } |
| 256 | + |
| 257 | + statement |
| 258 | + } |
| 259 | +} |
| 260 | + |
138 | 261 | #[cfg(test)] |
139 | 262 | mod tests { |
140 | 263 | use super::*; |
@@ -171,4 +294,37 @@ mod tests { |
171 | 294 | "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id" |
172 | 295 | ); |
173 | 296 | } |
| 297 | + |
| 298 | + #[test] |
| 299 | + fn test_qualifier_prepend() { |
| 300 | + let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = |
| 301 | + vec![Arc::new(ResolveUnqualifiedIdentifer)]; |
| 302 | + |
| 303 | + let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname"; |
| 304 | + let statement = parse(sql).expect("Failed to parse").remove(0); |
| 305 | + |
| 306 | + let statement = rewrite(statement, &rules); |
| 307 | + assert_eq!( |
| 308 | + statement.to_string(), |
| 309 | + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" |
| 310 | + ); |
| 311 | + |
| 312 | + let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"; |
| 313 | + let statement = parse(sql).expect("Failed to parse").remove(0); |
| 314 | + |
| 315 | + let statement = rewrite(statement, &rules); |
| 316 | + assert_eq!( |
| 317 | + statement.to_string(), |
| 318 | + "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname" |
| 319 | + ); |
| 320 | + |
| 321 | + let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname"; |
| 322 | + let statement = parse(sql).expect("Failed to parse").remove(0); |
| 323 | + |
| 324 | + let statement = rewrite(statement, &rules); |
| 325 | + assert_eq!( |
| 326 | + statement.to_string(), |
| 327 | + "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname" |
| 328 | + ); |
| 329 | + } |
174 | 330 | } |
0 commit comments