
Commit 33658e4

maropu authored and Robert Kruszewski committed
[SPARK-21351][SQL] Update nullability based on children's output
## What changes were proposed in this pull request? This pr added a new optimizer rule `UpdateNullabilityInAttributeReferences ` to update the nullability that `Filter` changes when having `IsNotNull`. In the master, optimized plans do not respect the nullability when `Filter` has `IsNotNull`. This wrongly generates unnecessary code. For example: ``` scala> val df = Seq((Some(1), Some(2))).toDF("a", "b") scala> val bIsNotNull = df.where($"b" =!= 2).select($"b") scala> val targetQuery = bIsNotNull.distinct scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = true scala> targetQuery.debugCodegen Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- Exchange hashpartitioning(b#19, 200) +- *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- *Project [_2#16 AS b#19] +- *Filter isnotnull(_2#16) +- LocalTableScan [_1#15, _2#16] Generated code: ... /* 124 */ protected void processNext() throws java.io.IOException { ... /* 132 */ // output the result /* 133 */ /* 134 */ while (agg_mapIter.next()) { /* 135 */ wholestagecodegen_numOutputRows.add(1); /* 136 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 137 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 138 */ /* 139 */ boolean agg_isNull4 = agg_aggKey.isNullAt(0); /* 140 */ int agg_value4 = agg_isNull4 ? -1 : (agg_aggKey.getInt(0)); /* 141 */ agg_rowWriter1.zeroOutNullBytes(); /* 142 */ // We don't need this NULL check because NULL is filtered out in `$"b" =!=2` /* 143 */ if (agg_isNull4) { /* 144 */ agg_rowWriter1.setNullAt(0); /* 145 */ } else { /* 146 */ agg_rowWriter1.write(0, agg_value4); /* 147 */ } /* 148 */ append(agg_result1); /* 149 */ /* 150 */ if (shouldStop()) return; /* 151 */ } /* 152 */ /* 153 */ agg_mapIter.close(); /* 154 */ if (agg_sorter == null) { /* 155 */ agg_hashMap.free(); /* 156 */ } /* 157 */ } /* 158 */ /* 159 */ } ``` In the line 143, we don't need this NULL check because NULL is filtered out in `$"b" =!=2`. This pr could remove this NULL check; ``` scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = false scala> targetQuery.debugCodegen ... Generated code: ... /* 144 */ protected void processNext() throws java.io.IOException { ... /* 152 */ // output the result /* 153 */ /* 154 */ while (agg_mapIter.next()) { /* 155 */ wholestagecodegen_numOutputRows.add(1); /* 156 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 157 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 158 */ /* 159 */ int agg_value4 = agg_aggKey.getInt(0); /* 160 */ agg_rowWriter1.write(0, agg_value4); /* 161 */ append(agg_result1); /* 162 */ /* 163 */ if (shouldStop()) return; /* 164 */ } /* 165 */ /* 166 */ agg_mapIter.close(); /* 167 */ if (agg_sorter == null) { /* 168 */ agg_hashMap.free(); /* 169 */ } /* 170 */ } ``` ## How was this patch tested? Added `UpdateNullabilityInAttributeReferencesSuite` for unit tests. Author: Takeshi Yamamuro <[email protected]> Closes apache#18576 from maropu/SPARK-21351.
1 parent d22e12e · commit 33658e4

File tree: 4 files changed, +75 −15 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 18 additions & 1 deletion
```diff
@@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       RewritePredicateSubquery,
       ColumnPruning,
       CollapseProject,
-      RemoveRedundantProject)
+      RemoveRedundantProject) :+
+    Batch("UpdateAttributeReferences", Once,
+      UpdateNullabilityInAttributeReferences)
   }
 
   /**
@@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Updates nullability in [[AttributeReference]]s if nullability is different between
+ * non-leaf plan's expressions and the children output.
+ */
+object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case p if !p.isInstanceOf[LeafNode] =>
+      val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable })
+      p transformExpressions {
+        case ar: AttributeReference if nullabilityMap.contains(ar) =>
+          ar.withNullability(nullabilityMap(ar))
+      }
+  }
+}
```
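The lookup at the heart of the rule relies on `AttributeMap` keying attributes by `exprId` rather than by full equality, so a reference carrying a stale nullability flag still finds its child's fresh output. A minimal sketch of that mechanic (not part of the patch; it uses Catalyst's internal API with illustrative names):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.types.LongType

val childOut = AttributeReference("a", LongType, nullable = false)()
val staleRef = childOut.withNullability(true) // same exprId, stale nullability flag

// Build the map the same way the rule does: child output attribute -> its nullability.
val nullabilityMap = AttributeMap(Seq(childOut).map { x => x -> x.nullable })

assert(nullabilityMap.contains(staleRef)) // matches by exprId despite the stale flag
assert(!staleRef.withNullability(nullabilityMap(staleRef)).nullable)
```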
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala

Lines changed: 57 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class UpdateNullabilityInAttributeReferencesSuite extends PlanTest {

  object Optimizer extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals,
        SimplifyBinaryComparison,
        SimplifyExtractValueOps) ::
      Batch("UpdateAttributeReferences", Once,
        UpdateNullabilityInAttributeReferences) :: Nil
  }

  test("update nullability in AttributeReference") {
    val rel = LocalRelation('a.long.notNull)
    // In the 'original' plans below, the Aggregate node produced by groupBy() has a
    // nullable AttributeReference to `b`, because both array indexing and map lookup are
    // nullable expressions. After optimization, the same attribute is now non-nullable,
    // but the AttributeReference is not updated to reflect this. So, we need to update
    // nullability by the `UpdateNullabilityInAttributeReferences` rule.
    val original = rel
      .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b")
      .groupBy($"b")("1")
    val expected = rel.select('a as "b").groupBy($"b")("1").analyze
    val optimized = Optimizer.execute(original.analyze)
    comparePlans(optimized, expected)
  }
}
```
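Beyond the RuleExecutor-driven suite above, the rule can also be exercised directly on a hand-built plan. A standalone sketch (not part of the patch; it uses Catalyst's internal `Project`/`AttributeReference` APIs and illustrative names) of the stale-reference scenario the rule fixes:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.LongType

// The relation outputs `a` as non-nullable, but the Project above it still holds a
// reference whose nullability flag is stale (true): same exprId, different flag.
val a = AttributeReference("a", LongType, nullable = false)()
val staleRef = a.withNullability(true)
val plan = Project(Seq(Alias(staleRef, "b")()), LocalRelation(a))

// The rule walks up from the leaves, maps each child's output nullability by exprId,
// and rewrites any AttributeReference whose flag disagrees with the child's output.
val fixed = UpdateNullabilityInAttributeReferences(plan)
val refreshed = fixed.asInstanceOf[Project].projectList.head
  .asInstanceOf[Alias].child.asInstanceOf[AttributeReference]
assert(!refreshed.nullable)
```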

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 0 additions & 9 deletions
```diff
@@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       .groupBy($"foo")("1")
     checkRule(structRel, structExpected)
 
-    // These tests must use nullable attributes from the base relation for the following reason:
-    // in the 'original' plans below, the Aggregate node produced by groupBy() has a
-    // nullable AttributeReference to a1, because both array indexing and map lookup are
-    // nullable expressions. After optimization, the same attribute is now non-nullable,
-    // but the AttributeReference is not updated to reflect this. In the 'expected' plans,
-    // the grouping expressions have the same nullability as the original attribute in the
-    // relation. If that attribute is non-nullable, the tests will fail as the plans will
-    // compare differently, so for these tests we must use a nullable attribute. See
-    // SPARK-23634.
     val arrayRel = relation
       .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
       .groupBy($"a1")("1")
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 0 additions & 5 deletions
```diff
@@ -2056,11 +2056,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       expr: String,
       expectedNonNullableColumns: Seq[String]): Unit = {
     val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
-    // In the logical plan, all the output columns of input dataframe are nullable
-    dfWithFilter.queryExecution.optimizedPlan.collect {
-      case e: Filter => assert(e.output.forall(_.nullable))
-    }
-
     dfWithFilter.queryExecution.executedPlan.collect {
       // When the child expression in isnotnull is null-intolerant (i.e. any null input will
       // result in null output), the involved columns are converted to not nullable;
```
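The deleted assertion pinned down the old behavior, namely that every `Filter` output stayed nullable in the optimized logical plan, and this patch deliberately changes that. Condensing the REPL example from the commit message above into one transcript:

```
scala> val df = Seq((Some(1), Some(2))).toDF("a", "b")
scala> val targetQuery = df.where($"b" =!= 2).select($"b").distinct
scala> targetQuery.queryExecution.optimizedPlan.output(0).nullable
res5: Boolean = false
```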
