|
| 1 | +use datafusion::sql::sqlparser::ast::Expr; |
| 2 | +use datafusion::sql::sqlparser::ast::Ident; |
| 3 | +use datafusion::sql::sqlparser::ast::Select; |
| 4 | +use datafusion::sql::sqlparser::ast::SelectItem; |
| 5 | +use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind; |
| 6 | +use datafusion::sql::sqlparser::ast::SetExpr; |
| 7 | +use datafusion::sql::sqlparser::ast::Statement; |
| 8 | +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; |
| 9 | +use datafusion::sql::sqlparser::parser::Parser; |
| 10 | +use datafusion::sql::sqlparser::parser::ParserError; |
| 11 | + |
| 12 | +pub fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> { |
| 13 | + let dialect = PostgreSqlDialect {}; |
| 14 | + |
| 15 | + Parser::parse_sql(&dialect, sql) |
| 16 | +} |
| 17 | + |
| 18 | +pub fn rewrite(mut s: Statement, rules: &[Box<dyn SqlStatementRewriteRule>]) -> Statement { |
| 19 | + for rule in rules { |
| 20 | + s = rule.rewrite(s); |
| 21 | + } |
| 22 | + |
| 23 | + s |
| 24 | +} |
| 25 | + |
/// A single transformation step over a parsed SQL [`Statement`].
///
/// Implementations take ownership of the statement and hand back the
/// (possibly modified) statement, so rules can be chained one after another.
pub trait SqlStatementRewriteRule {
    /// Consumes `s` and returns the rewritten statement.
    fn rewrite(&self, s: Statement) -> Statement;
}
| 29 | + |
| 30 | +/// Rewrite rule for adding alias to duplicated projection |
| 31 | +/// |
| 32 | +/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a |
| 33 | +/// valid statement in postgres. But datafusion treat it as illegal because of |
| 34 | +/// duplicated column oid in projection. |
| 35 | +/// |
| 36 | +/// This rule will add alias to column, when there is a wildcard found in |
| 37 | +/// projection. |
| 38 | +struct AliasDuplicatedProjectionRewrite; |
| 39 | + |
| 40 | +impl AliasDuplicatedProjectionRewrite { |
| 41 | + // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard. |
| 42 | + fn rewrite_select_with_alias(select: &mut Box<Select>) { |
| 43 | + // 1. Collect all table aliases from qualified wildcards. |
| 44 | + let mut wildcard_tables = Vec::new(); |
| 45 | + let mut has_simple_wildcard = false; |
| 46 | + for p in &select.projection { |
| 47 | + match p { |
| 48 | + SelectItem::QualifiedWildcard(name, _) => match name { |
| 49 | + SelectItemQualifiedWildcardKind::ObjectName(objname) => { |
| 50 | + // for n.oid, |
| 51 | + let idents = objname |
| 52 | + .0 |
| 53 | + .iter() |
| 54 | + .map(|v| v.as_ident().unwrap().value.clone()) |
| 55 | + .collect::<Vec<_>>() |
| 56 | + .join("."); |
| 57 | + |
| 58 | + wildcard_tables.push(idents); |
| 59 | + } |
| 60 | + SelectItemQualifiedWildcardKind::Expr(_expr) => { |
| 61 | + // FIXME: |
| 62 | + } |
| 63 | + }, |
| 64 | + SelectItem::Wildcard(_) => { |
| 65 | + has_simple_wildcard = true; |
| 66 | + } |
| 67 | + _ => {} |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + // If there are no qualified wildcards, there's nothing to do. |
| 72 | + if wildcard_tables.is_empty() && !has_simple_wildcard { |
| 73 | + return; |
| 74 | + } |
| 75 | + |
| 76 | + // 2. Rewrite the projection, adding aliases to matching columns. |
| 77 | + let mut new_projection = vec![]; |
| 78 | + for p in select.projection.drain(..) { |
| 79 | + match p { |
| 80 | + SelectItem::UnnamedExpr(expr) => { |
| 81 | + let alias_partial = match &expr { |
| 82 | + // Case for `oid` (unqualified identifier) |
| 83 | + Expr::Identifier(ident) => Some(ident.clone()), |
| 84 | + // Case for `n.oid` (compound identifier) |
| 85 | + Expr::CompoundIdentifier(idents) => { |
| 86 | + // compare every ident but the last |
| 87 | + if idents.len() > 1 { |
| 88 | + let table_name = &idents[..idents.len() - 1] |
| 89 | + .iter() |
| 90 | + .map(|i| i.value.clone()) |
| 91 | + .collect::<Vec<_>>() |
| 92 | + .join("."); |
| 93 | + if wildcard_tables.iter().any(|name| name == table_name) { |
| 94 | + Some(idents[idents.len() - 1].clone()) |
| 95 | + } else { |
| 96 | + None |
| 97 | + } |
| 98 | + } else { |
| 99 | + None |
| 100 | + } |
| 101 | + } |
| 102 | + _ => None, |
| 103 | + }; |
| 104 | + |
| 105 | + if let Some(name) = alias_partial { |
| 106 | + let alias = format!("__alias_{}", name); |
| 107 | + new_projection.push(SelectItem::ExprWithAlias { |
| 108 | + expr, |
| 109 | + alias: Ident::new(alias), |
| 110 | + }); |
| 111 | + } else { |
| 112 | + new_projection.push(SelectItem::UnnamedExpr(expr)); |
| 113 | + } |
| 114 | + } |
| 115 | + // Preserve existing aliases and wildcards. |
| 116 | + _ => new_projection.push(p), |
| 117 | + } |
| 118 | + } |
| 119 | + select.projection = new_projection; |
| 120 | + } |
| 121 | +} |
| 122 | + |
| 123 | +impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite { |
| 124 | + fn rewrite(&self, mut statement: Statement) -> Statement { |
| 125 | + if let Statement::Query(query) = &mut statement { |
| 126 | + if let SetExpr::Select(select) = query.body.as_mut() { |
| 127 | + Self::rewrite_select_with_alias(select); |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + statement |
| 132 | + } |
| 133 | +} |
| 134 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql`, runs `rules` over its first statement and renders the
    /// result back to SQL text.
    fn rewrite_sql(sql: &str, rules: &[Box<dyn SqlStatementRewriteRule>]) -> String {
        let statement = parse(sql).expect("Failed to parse").remove(0);
        rewrite(statement, rules).to_string()
    }

    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Box<dyn SqlStatementRewriteRule>> =
            vec![Box::new(AliasDuplicatedProjectionRewrite)];

        // (input SQL, expected SQL after rewriting)
        let cases = [
            // Qualified column next to a wildcard over the same qualifier.
            (
                "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
                "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n",
            ),
            // Unqualified column next to a bare wildcard.
            (
                "SELECT oid, * FROM pg_catalog.pg_namespace",
                "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace",
            ),
            // Column and wildcard use different qualifiers: left untouched.
            (
                "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
                "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(rewrite_sql(input, &rules), expected);
        }
    }
}
0 commit comments