|
| 1 | +use std::collections::HashSet; |
1 | 2 | use std::sync::Arc; |
2 | 3 |
|
3 | 4 | use datafusion::sql::sqlparser::ast::Expr; |
4 | 5 | use datafusion::sql::sqlparser::ast::Ident; |
| 6 | +use datafusion::sql::sqlparser::ast::OrderByKind; |
| 7 | +use datafusion::sql::sqlparser::ast::Query; |
5 | 8 | use datafusion::sql::sqlparser::ast::Select; |
6 | 9 | use datafusion::sql::sqlparser::ast::SelectItem; |
7 | 10 | use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind; |
8 | 11 | use datafusion::sql::sqlparser::ast::SetExpr; |
9 | 12 | use datafusion::sql::sqlparser::ast::Statement; |
| 13 | +use datafusion::sql::sqlparser::ast::TableFactor; |
| 14 | +use datafusion::sql::sqlparser::ast::TableWithJoins; |
| 15 | +use datafusion::sql::sqlparser::ast::Value; |
10 | 16 | use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; |
11 | 17 | use datafusion::sql::sqlparser::parser::Parser; |
12 | 18 | use datafusion::sql::sqlparser::parser::ParserError; |
@@ -135,6 +141,188 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite { |
135 | 141 | } |
136 | 142 | } |
137 | 143 |
|
| 144 | +/// Prepend qualifier for order by or filter when there is qualified wildcard |
| 145 | +/// |
| 146 | +/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not |
| 147 | +/// accepted by datafusion. |
| 148 | +#[derive(Debug)] |
| 149 | +pub struct ResolveUnqualifiedIdentifer; |
| 150 | + |
| 151 | +impl ResolveUnqualifiedIdentifer { |
| 152 | + fn rewrite_unqualified_identifiers(query: &mut Box<Query>) { |
| 153 | + if let SetExpr::Select(select) = query.body.as_mut() { |
| 154 | + // Step 1: Find all table aliases from FROM and JOIN clauses. |
| 155 | + let table_aliases = Self::get_table_aliases(&select.from); |
| 156 | + |
| 157 | + // Step 2: Check for a single qualified wildcard in the projection. |
| 158 | + let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection); |
| 159 | + if qualified_wildcard_alias.is_none() || table_aliases.is_empty() { |
| 160 | + return; // Conditions not met. |
| 161 | + } |
| 162 | + |
| 163 | + let wildcard_alias = qualified_wildcard_alias.unwrap(); |
| 164 | + |
| 165 | + // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses. |
| 166 | + if let Some(selection) = &mut select.selection { |
| 167 | + Self::rewrite_expr(selection, &wildcard_alias, &table_aliases); |
| 168 | + } |
| 169 | + |
| 170 | + if let Some(OrderByKind::Expressions(order_by_exprs)) = |
| 171 | + query.order_by.as_mut().map(|o| &mut o.kind) |
| 172 | + { |
| 173 | + for order_by_expr in order_by_exprs { |
| 174 | + Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases); |
| 175 | + } |
| 176 | + } |
| 177 | + } |
| 178 | + } |
| 179 | + |
| 180 | + fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> { |
| 181 | + let mut aliases = HashSet::new(); |
| 182 | + for table_with_joins in tables { |
| 183 | + if let TableFactor::Table { |
| 184 | + alias: Some(alias), .. |
| 185 | + } = &table_with_joins.relation |
| 186 | + { |
| 187 | + aliases.insert(alias.name.value.clone()); |
| 188 | + } |
| 189 | + for join in &table_with_joins.joins { |
| 190 | + if let TableFactor::Table { |
| 191 | + alias: Some(alias), .. |
| 192 | + } = &join.relation |
| 193 | + { |
| 194 | + aliases.insert(alias.name.value.clone()); |
| 195 | + } |
| 196 | + } |
| 197 | + } |
| 198 | + aliases |
| 199 | + } |
| 200 | + |
| 201 | + fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> { |
| 202 | + let mut qualified_wildcards = projection |
| 203 | + .iter() |
| 204 | + .filter_map(|item| { |
| 205 | + if let SelectItem::QualifiedWildcard( |
| 206 | + SelectItemQualifiedWildcardKind::ObjectName(objname), |
| 207 | + _, |
| 208 | + ) = item |
| 209 | + { |
| 210 | + Some( |
| 211 | + objname |
| 212 | + .0 |
| 213 | + .iter() |
| 214 | + .map(|v| v.as_ident().unwrap().value.clone()) |
| 215 | + .collect::<Vec<_>>() |
| 216 | + .join("."), |
| 217 | + ) |
| 218 | + } else { |
| 219 | + None |
| 220 | + } |
| 221 | + }) |
| 222 | + .collect::<Vec<_>>(); |
| 223 | + |
| 224 | + if qualified_wildcards.len() == 1 { |
| 225 | + Some(qualified_wildcards.remove(0)) |
| 226 | + } else { |
| 227 | + None |
| 228 | + } |
| 229 | + } |
| 230 | + |
| 231 | + fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) { |
| 232 | + match expr { |
| 233 | + Expr::Identifier(ident) => { |
| 234 | + // If the identifier is not a table alias itself, rewrite it. |
| 235 | + if !table_aliases.contains(&ident.value) { |
| 236 | + *expr = Expr::CompoundIdentifier(vec![ |
| 237 | + Ident::new(wildcard_alias.to_string()), |
| 238 | + ident.clone(), |
| 239 | + ]); |
| 240 | + } |
| 241 | + } |
| 242 | + Expr::BinaryOp { left, right, .. } => { |
| 243 | + Self::rewrite_expr(left, wildcard_alias, table_aliases); |
| 244 | + Self::rewrite_expr(right, wildcard_alias, table_aliases); |
| 245 | + } |
| 246 | + // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.) |
| 247 | + _ => {} |
| 248 | + } |
| 249 | + } |
| 250 | +} |
| 251 | + |
| 252 | +impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer { |
| 253 | + fn rewrite(&self, mut statement: Statement) -> Statement { |
| 254 | + if let Statement::Query(query) = &mut statement { |
| 255 | + Self::rewrite_unqualified_identifiers(query); |
| 256 | + } |
| 257 | + |
| 258 | + statement |
| 259 | + } |
| 260 | +} |
| 261 | + |
| 262 | +/// Remove datafusion unsupported type annotations |
| 263 | +#[derive(Debug)] |
| 264 | +pub struct RemoveUnsupportedTypes { |
| 265 | + unsupported_types: HashSet<String>, |
| 266 | +} |
| 267 | + |
| 268 | +impl RemoveUnsupportedTypes { |
| 269 | + pub fn new() -> Self { |
| 270 | + let mut unsupported_types = HashSet::new(); |
| 271 | + unsupported_types.insert("regclass".to_owned()); |
| 272 | + |
| 273 | + Self { unsupported_types } |
| 274 | + } |
| 275 | + |
| 276 | + fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) { |
| 277 | + match expr { |
| 278 | + // This is the key part: identify constants with type annotations. |
| 279 | + Expr::TypedString { value, data_type } => { |
| 280 | + if self |
| 281 | + .unsupported_types |
| 282 | + .contains(data_type.to_string().to_lowercase().as_str()) |
| 283 | + { |
| 284 | + *expr = |
| 285 | + Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span()); |
| 286 | + } |
| 287 | + } |
| 288 | + Expr::Cast { |
| 289 | + data_type, |
| 290 | + expr: value, |
| 291 | + .. |
| 292 | + } => { |
| 293 | + if self |
| 294 | + .unsupported_types |
| 295 | + .contains(data_type.to_string().to_lowercase().as_str()) |
| 296 | + { |
| 297 | + *expr = *value.clone(); |
| 298 | + } |
| 299 | + } |
| 300 | + // Handle binary operations by recursively rewriting both sides. |
| 301 | + Expr::BinaryOp { left, right, .. } => { |
| 302 | + self.rewrite_expr_unsupported_types(left); |
| 303 | + self.rewrite_expr_unsupported_types(right); |
| 304 | + } |
| 305 | + // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed. |
| 306 | + _ => {} |
| 307 | + } |
| 308 | + } |
| 309 | +} |
| 310 | + |
| 311 | +impl SqlStatementRewriteRule for RemoveUnsupportedTypes { |
| 312 | + fn rewrite(&self, mut s: Statement) -> Statement { |
| 313 | + // Traverse the AST to find the WHERE clause and rewrite it. |
| 314 | + if let Statement::Query(query) = &mut s { |
| 315 | + if let SetExpr::Select(select) = query.body.as_mut() { |
| 316 | + if let Some(expr) = &mut select.selection { |
| 317 | + self.rewrite_expr_unsupported_types(expr); |
| 318 | + } |
| 319 | + } |
| 320 | + } |
| 321 | + |
| 322 | + s |
| 323 | + } |
| 324 | +} |
| 325 | + |
138 | 326 | #[cfg(test)] |
139 | 327 | mod tests { |
140 | 328 | use super::*; |
@@ -180,4 +368,61 @@ mod tests { |
180 | 368 | "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname" |
181 | 369 | ); |
182 | 370 | } |
| 371 | + |
#[test]
fn test_qualifier_prepend() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(ResolveUnqualifiedIdentifer)];

    // Each case pairs an input statement with its expected text after rewriting.
    let cases = [
        // Single aliased table: WHERE and ORDER BY columns get the qualifier.
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        // No qualified wildcard and no alias: the statement is left untouched.
        (
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
        ),
        // Joined tables with one qualified wildcard: ORDER BY uses its qualifier.
        (
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname",
        ),
    ];

    for (sql, expected) in cases {
        let mut parsed = parse(sql).expect("Failed to parse");
        let rewritten = rewrite(parsed.remove(0), &rules);
        assert_eq!(rewritten.to_string(), expected);
    }
}
| 404 | + |
#[test]
fn test_remove_unsupported_types() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveUnsupportedTypes::new())];

    // Both inputs must come out as the same cast-free statement.
    let expected = "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
    let inputs = [
        // A `::regclass` cast in the WHERE clause is removed.
        "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
        // A statement without any cast is unaffected.
        "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
    ];

    for sql in inputs {
        let mut parsed = parse(sql).expect("Failed to parse");
        let rewritten = rewrite(parsed.remove(0), &rules);
        assert_eq!(rewritten.to_string(), expected);
    }
}
183 | 428 | } |
0 commit comments