@@ -90,14 +90,22 @@ impl Column {
9090 /// For example, `foo` will be normalized to `t.foo` if there is a
9191 /// column named `foo` in a relation named `t` found in `schemas`
9292 pub fn normalize ( self , plan : & LogicalPlan ) -> Result < Self > {
93+ let schemas = plan. all_schemas ( ) ;
94+ let using_columns = plan. using_columns ( ) ?;
95+ self . normalize_with_schemas ( & schemas, & using_columns)
96+ }
97+
98+ // Internal implementation of normalize
99+ fn normalize_with_schemas (
100+ self ,
101+ schemas : & [ & Arc < DFSchema > ] ,
102+ using_columns : & [ HashSet < Column > ] ,
103+ ) -> Result < Self > {
93104 if self . relation . is_some ( ) {
94105 return Ok ( self ) ;
95106 }
96107
97- let schemas = plan. all_schemas ( ) ;
98- let using_columns = plan. using_columns ( ) ?;
99-
100- for schema in & schemas {
108+ for schema in schemas {
101109 let fields = schema. fields_with_unqualified_name ( & self . name ) ;
102110 match fields. len ( ) {
103111 0 => continue ,
@@ -118,7 +126,7 @@ impl Column {
118126 // We will use the relation from the first matched field to normalize self.
119127
120128 // Compare matched fields with one USING JOIN clause at a time
121- for using_col in & using_columns {
129+ for using_col in using_columns {
122130 let all_matched = fields
123131 . iter ( )
124132 . all ( |f| using_col. contains ( & f. qualified_column ( ) ) ) ;
@@ -1171,22 +1179,39 @@ pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<E
11711179
11721180/// Recursively call [`Column::normalize`] on all Column expressions
11731181/// in the `expr` expression tree.
1174- pub fn normalize_col ( e : Expr , plan : & LogicalPlan ) -> Result < Expr > {
1182+ pub fn normalize_col ( expr : Expr , plan : & LogicalPlan ) -> Result < Expr > {
1183+ normalize_col_with_schemas ( expr, & plan. all_schemas ( ) , & plan. using_columns ( ) ?)
1184+ }
1185+
1186+ /// Recursively call [`Column::normalize`] on all Column expressions
1187+ /// in the `expr` expression tree.
1188+ fn normalize_col_with_schemas (
1189+ expr : Expr ,
1190+ schemas : & [ & Arc < DFSchema > ] ,
1191+ using_columns : & [ HashSet < Column > ] ,
1192+ ) -> Result < Expr > {
11751193 struct ColumnNormalizer < ' a > {
1176- plan : & ' a LogicalPlan ,
1194+ schemas : & ' a [ & ' a Arc < DFSchema > ] ,
1195+ using_columns : & ' a [ HashSet < Column > ] ,
11771196 }
11781197
11791198 impl < ' a > ExprRewriter for ColumnNormalizer < ' a > {
11801199 fn mutate ( & mut self , expr : Expr ) -> Result < Expr > {
11811200 if let Expr :: Column ( c) = expr {
1182- Ok ( Expr :: Column ( c. normalize ( self . plan ) ?) )
1201+ Ok ( Expr :: Column ( c. normalize_with_schemas (
1202+ self . schemas ,
1203+ self . using_columns ,
1204+ ) ?) )
11831205 } else {
11841206 Ok ( expr)
11851207 }
11861208 }
11871209 }
11881210
1189- e. rewrite ( & mut ColumnNormalizer { plan } )
1211+ expr. rewrite ( & mut ColumnNormalizer {
1212+ schemas,
1213+ using_columns,
1214+ } )
11901215}
11911216
11921217/// Recursively normalize all Column expressions in a list of expression trees
@@ -1198,6 +1223,38 @@ pub fn normalize_cols(
11981223 exprs. into_iter ( ) . map ( |e| normalize_col ( e, plan) ) . collect ( )
11991224}
12001225
1226+ /// Recursively 'unnormalize' (remove all qualifiers) from an
1227+ /// expression tree.
1228+ ///
1229+ /// For example, if there were expressions like `foo.bar` this would
1230+ /// rewrite it to just `bar`.
1231+ pub fn unnormalize_col ( expr : Expr ) -> Expr {
1232+ struct RemoveQualifier { }
1233+
1234+ impl ExprRewriter for RemoveQualifier {
1235+ fn mutate ( & mut self , expr : Expr ) -> Result < Expr > {
1236+ if let Expr :: Column ( col) = expr {
1237+ //let Column { relation: _, name } = col;
1238+ Ok ( Expr :: Column ( Column {
1239+ relation : None ,
1240+ name : col. name ,
1241+ } ) )
1242+ } else {
1243+ Ok ( expr)
1244+ }
1245+ }
1246+ }
1247+
1248+ expr. rewrite ( & mut RemoveQualifier { } )
1249+ . expect ( "Unnormalize is infallable" )
1250+ }
1251+
1252+ /// Recursively un-normalize all Column expressions in a list of expression trees
1253+ #[ inline]
1254+ pub fn unnormalize_cols ( exprs : impl IntoIterator < Item = Expr > ) -> Vec < Expr > {
1255+ exprs. into_iter ( ) . map ( unnormalize_col) . collect ( )
1256+ }
1257+
12011258/// Create an expression to represent the min() aggregate function
12021259pub fn min ( expr : Expr ) -> Expr {
12031260 Expr :: AggregateFunction {
@@ -1810,4 +1867,78 @@ mod tests {
18101867 }
18111868 }
18121869 }
1870+
#[test]
fn normalize_cols() {
    // Schemas are listed in the order c, f, b, a: three of them contain one
    // of the referenced columns (plus an unrelated one), while the `f`/`ff`
    // schema matches nothing in the expression.
    let owned_schemas: Vec<Arc<DFSchema>> = vec![
        DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]),
        // non matching
        DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]),
        DFSchema::new(vec![make_field("tableB", "b")]),
        DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]),
    ]
    .into_iter()
    .map(|schema| Arc::new(schema.unwrap()))
    .collect();
    let schema_refs: Vec<&Arc<DFSchema>> = owned_schemas.iter().collect();

    let unqualified = col("a") + col("b") + col("c");
    let normalized_expr =
        normalize_col_with_schemas(unqualified, &schema_refs, &[]).unwrap();

    assert_eq!(
        normalized_expr,
        col("tableA.a") + col("tableB.b") + col("tableC.c")
    );
}
1899+
#[test]
fn normalize_cols_priority() {
    // Column `a` exists in both tableA2 and tableA; tableA2 comes first in
    // the schema list, so it is the relation expected in the result.
    let owned_schemas: Vec<Arc<DFSchema>> =
        [("tableA2", "a"), ("tableB", "b"), ("tableA", "a")]
            .iter()
            .map(|&(relation, column)| {
                Arc::new(DFSchema::new(vec![make_field(relation, column)]).unwrap())
            })
            .collect();
    let schema_refs: Vec<&Arc<DFSchema>> = owned_schemas.iter().collect();

    let unqualified = col("a") + col("b");
    let normalized_expr =
        normalize_col_with_schemas(unqualified, &schema_refs, &[]).unwrap();

    assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
}
1916+
#[test]
fn normalize_cols_non_exist() {
    // Only `a` is available, so normalizing `b` must fail with a planning error.
    let only_a = Arc::new(DFSchema::new(vec![make_field("tableA", "a")]).unwrap());
    let schema_refs: Vec<&Arc<DFSchema>> = vec![&only_a];

    let outcome = normalize_col_with_schemas(col("a") + col("b"), &schema_refs, &[]);
    let error = outcome.unwrap_err().to_string();

    assert_eq!(
        error,
        "Error during planning: Column #b not found in provided schemas"
    );
}
1933+
#[test]
fn unnormalize_cols() {
    // Qualified column references should come back without their relation.
    let qualified = col("tableA.a") + col("tableB.b");
    assert_eq!(unnormalize_col(qualified), col("a") + col("b"));
}
1940+
1941+ fn make_field ( relation : & str , column : & str ) -> DFField {
1942+ DFField :: new ( Some ( relation) , column, DataType :: Int8 , false )
1943+ }
18131944}
0 commit comments