Skip to content

Commit fbc2692

Browse files
DonnyZonegatorsmile
authored andcommitted
[SPARK-19471][SQL] AggregationIterator does not initialize the generated result projection before using it
## What changes were proposed in this pull request? Recently, we have also encountered such NPE issues in our production environment as described in: https://issues.apache.org/jira/browse/SPARK-19471 This issue can be reproduced by the following examples: ` val df = spark.createDataFrame(Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4))).toDF("x", "y") //HashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false df.groupBy("x").agg(rand(),sum("y")).show() //ObjectHashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false df.groupBy("x").agg(rand(),collect_list("y")).show() //SortAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key=false &&SQLConf.USE_OBJECT_HASH_AGG.key=false df.groupBy("x").agg(rand(),collect_list("y")).show()` ` This PR is based on PR-16820(apache#16820) with test cases for all aggregation paths. We want to push it forward. > When AggregationIterator generates result projection, it does not call the initialize method of the Projection class. This will cause a runtime NullPointerException when the projection involves nondeterministic expressions. ## How was this patch tested? unit test verified in production environment Author: donnyzone <[email protected]> Closes apache#18920 from DonnyZone/Branch-spark-19471.
1 parent 0326b69 commit fbc2692

File tree

8 files changed

+63
-3
lines changed

8 files changed

+63
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
3333
* is used to generate result.
3434
*/
3535
abstract class AggregationIterator(
36+
partIndex: Int,
3637
groupingExpressions: Seq[NamedExpression],
3738
inputAttributes: Seq[Attribute],
3839
aggregateExpressions: Seq[AggregateExpression],
@@ -217,6 +218,7 @@ abstract class AggregationIterator(
217218

218219
val resultProjection =
219220
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
221+
resultProjection.initialize(partIndex)
220222

221223
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
222224
// Generate results for all expression-based aggregate functions.
@@ -235,6 +237,7 @@ abstract class AggregationIterator(
235237
val resultProjection = UnsafeProjection.create(
236238
groupingAttributes ++ bufferAttributes,
237239
groupingAttributes ++ bufferAttributes)
240+
resultProjection.initialize(partIndex)
238241

239242
// TypedImperativeAggregate stores generic object in aggregation buffer, and requires
240243
// calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
@@ -256,6 +259,7 @@ abstract class AggregationIterator(
256259
} else {
257260
// Grouping-only: we only output values based on grouping expressions.
258261
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
262+
resultProjection.initialize(partIndex)
259263
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
260264
resultProjection(currentGroupingKey)
261265
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ case class HashAggregateExec(
9696
val spillSize = longMetric("spillSize")
9797
val avgHashProbe = longMetric("avgHashProbe")
9898

99-
child.execute().mapPartitions { iter =>
99+
child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
100100

101101
val hasInput = iter.hasNext
102102
if (!hasInput && groupingExpressions.nonEmpty) {
@@ -106,6 +106,7 @@ case class HashAggregateExec(
106106
} else {
107107
val aggregationIterator =
108108
new TungstenAggregationIterator(
109+
partIndex,
109110
groupingExpressions,
110111
aggregateExpressions,
111112
aggregateAttributes,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
3131
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
3232

3333
class ObjectAggregationIterator(
34+
partIndex: Int,
3435
outputAttributes: Seq[Attribute],
3536
groupingExpressions: Seq[NamedExpression],
3637
aggregateExpressions: Seq[AggregateExpression],
@@ -43,6 +44,7 @@ class ObjectAggregationIterator(
4344
fallbackCountThreshold: Int,
4445
numOutputRows: SQLMetric)
4546
extends AggregationIterator(
47+
partIndex,
4648
groupingExpressions,
4749
originalInputAttributes,
4850
aggregateExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
9898
val numOutputRows = longMetric("numOutputRows")
9999
val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
100100

101-
child.execute().mapPartitionsInternal { iter =>
101+
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
102102
val hasInput = iter.hasNext
103103
if (!hasInput && groupingExpressions.nonEmpty) {
104104
// This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@ case class ObjectHashAggregateExec(
107107
} else {
108108
val aggregationIterator =
109109
new ObjectAggregationIterator(
110+
partIndex,
110111
child.output,
111112
groupingExpressions,
112113
aggregateExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ case class SortAggregateExec(
7474

7575
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
7676
val numOutputRows = longMetric("numOutputRows")
77-
child.execute().mapPartitionsInternal { iter =>
77+
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
7878
// Because the constructor of an aggregation iterator will read at least the first row,
7979
// we need to get the value of iter.hasNext first.
8080
val hasInput = iter.hasNext
@@ -84,6 +84,7 @@ case class SortAggregateExec(
8484
Iterator[UnsafeRow]()
8585
} else {
8686
val outputIter = new SortBasedAggregationIterator(
87+
partIndex,
8788
groupingExpressions,
8889
child.output,
8990
iter,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
2727
* sorted by values of [[groupingExpressions]].
2828
*/
2929
class SortBasedAggregationIterator(
30+
partIndex: Int,
3031
groupingExpressions: Seq[NamedExpression],
3132
valueAttributes: Seq[Attribute],
3233
inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
3738
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
3839
numOutputRows: SQLMetric)
3940
extends AggregationIterator(
41+
partIndex,
4042
groupingExpressions,
4143
valueAttributes,
4244
aggregateExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator
6060
* - Part 8: A utility function used to generate a result when there is no
6161
* input and there is no grouping expression.
6262
*
63+
* @param partIndex
64+
* index of the partition
6365
* @param groupingExpressions
6466
* expressions for grouping keys
6567
* @param aggregateExpressions
@@ -77,6 +79,7 @@ import org.apache.spark.unsafe.KVIterator
7779
* the iterator containing input [[UnsafeRow]]s.
7880
*/
7981
class TungstenAggregationIterator(
82+
partIndex: Int,
8083
groupingExpressions: Seq[NamedExpression],
8184
aggregateExpressions: Seq[AggregateExpression],
8285
aggregateAttributes: Seq[Attribute],
@@ -91,6 +94,7 @@ class TungstenAggregationIterator(
9194
spillSize: SQLMetric,
9295
avgHashProbe: SQLMetric)
9396
extends AggregationIterator(
97+
partIndex,
9498
groupingExpressions,
9599
originalInputAttributes,
96100
aggregateExpressions,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import scala.util.Random
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions.Expression
2626
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
27+
import org.apache.spark.sql.execution.WholeStageCodegenExec
28+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
2729
import org.apache.spark.sql.functions._
2830
import org.apache.spark.sql.internal.SQLConf
2931
import org.apache.spark.sql.test.SharedSQLContext
@@ -449,6 +451,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
449451
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
450452
}
451453

454+
private def assertNoExceptions(c: Column): Unit = {
455+
for ((wholeStage, useObjectHashAgg) <-
456+
Seq((true, true), (true, false), (false, true), (false, false))) {
457+
withSQLConf(
458+
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
459+
(SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
460+
461+
val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
462+
463+
// HashAggregate test case
464+
val hashAggDF = df.groupBy("x").agg(c, sum("y"))
465+
val hashAggPlan = hashAggDF.queryExecution.executedPlan
466+
if (wholeStage) {
467+
assert(hashAggPlan.find {
468+
case WholeStageCodegenExec(_: HashAggregateExec) => true
469+
case _ => false
470+
}.isDefined)
471+
} else {
472+
assert(hashAggPlan.isInstanceOf[HashAggregateExec])
473+
}
474+
hashAggDF.collect()
475+
476+
// ObjectHashAggregate and SortAggregate test case
477+
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
478+
val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
479+
if (useObjectHashAgg) {
480+
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
481+
} else {
482+
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
483+
}
484+
objHashAggOrSortAggDF.collect()
485+
}
486+
}
487+
}
488+
489+
test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
490+
" before using it") {
491+
Seq(
492+
monotonically_increasing_id(), spark_partition_id(),
493+
rand(Random.nextLong()), randn(Random.nextLong())
494+
).foreach(assertNoExceptions)
495+
}
496+
452497
test("SPARK-21281 use string types by default if array and map have no argument") {
453498
val ds = spark.range(1)
454499
var expectedSchema = new StructType()

0 commit comments

Comments
 (0)