
Commit bc99025

DonnyZone authored and gatorsmile committed
[SPARK-19471][SQL] AggregationIterator does not initialize the generated result projection before using it
## What changes were proposed in this pull request?

This is a follow-up PR that moves the test case introduced in apache#18920 to DataFrameAggregateSuite.

## How was this patch tested?

Unit test.

Author: donnyzone <[email protected]>

Closes apache#18946 from DonnyZone/branch-19471-followingPR.
1 parent 12411b5 · commit bc99025
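
For context (illustrative, not part of the commit): the regression guarded by the moved test is an aggregation that mixes a non-deterministic expression with a regular aggregate function, which could fail at execution time before the SPARK-19471 fix because the generated result projection was used without being initialized. A minimal sketch of that query shape, assuming a spark-shell style session where `spark` and its implicits are available:

import org.apache.spark.sql.functions._
import spark.implicits._

// Illustrative only; mirrors the query shapes exercised by the moved test.
val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// Non-deterministic expressions evaluated alongside aggregate functions; prior to
// the SPARK-19471 fix, running these could fail instead of returning grouped rows.
df.groupBy("x").agg(monotonically_increasing_id(), sum("y")).collect()
df.groupBy("x").agg(spark_partition_id(), collect_list("y")).collect()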

File tree

2 files changed: +47 -45 lines

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 47 additions & 0 deletions
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import scala.util.Random
+
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -558,6 +562,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
   }
 
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <-
+      Seq((true, true), (true, false), (false, true), (false, false))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+
+        // test case for HashAggregate
+        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
+        val hashAggPlan = hashAggDF.queryExecution.executedPlan
+        if (wholeStage) {
+          assert(hashAggPlan.find {
+            case WholeStageCodegenExec(_: HashAggregateExec) => true
+            case _ => false
+          }.isDefined)
+        } else {
+          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
+        }
+        hashAggDF.collect()
+
+        // test case for ObjectHashAggregate and SortAggregate
+        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
+        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
+        if (useObjectHashAgg) {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+        } else {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
+        }
+        objHashAggOrSortAggDF.collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions)
+  }
+
   test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") {
     checkAnswer(
       testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")),
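
A note on the configuration matrix in the helper above: the two SQLConf entries it toggles control whole-stage code generation and the object-hash aggregate, so the four (wholeStage, useObjectHashAgg) combinations steer the planner through HashAggregateExec, ObjectHashAggregateExec, and SortAggregateExec. Below is a rough spark-shell sketch of exercising one of those combinations by hand; the literal key strings are assumptions about what SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key and SQLConf.USE_OBJECT_HASH_AGG.key resolve to and should be checked against the Spark version in use.

import org.apache.spark.sql.functions._
import spark.implicits._

// Assumed key strings; verify them against SQLConf for your Spark version.
spark.conf.set("spark.sql.codegen.wholeStage", "false")
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")

val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// With both flags disabled, the collect_list aggregate should plan as a SortAggregate,
// matching the SortAggregateExec branch asserted by the test; explain() prints the
// physical plan so the aggregate operator can be inspected.
df.groupBy("x").agg(rand(), collect_list("y")).explain()

The test itself drives the same combinations through withSQLConf, which scopes each setting to the enclosing block and restores it afterwards.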

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 0 additions & 45 deletions
@@ -24,8 +24,6 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -451,49 +449,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
 
-  private def assertNoExceptions(c: Column): Unit = {
-    for ((wholeStage, useObjectHashAgg) <-
-      Seq((true, true), (true, false), (false, true), (false, false))) {
-      withSQLConf(
-        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
-        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
-
-        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
-
-        // HashAggregate test case
-        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
-        val hashAggPlan = hashAggDF.queryExecution.executedPlan
-        if (wholeStage) {
-          assert(hashAggPlan.find {
-            case WholeStageCodegenExec(_: HashAggregateExec) => true
-            case _ => false
-          }.isDefined)
-        } else {
-          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
-        }
-        hashAggDF.collect()
-
-        // ObjectHashAggregate and SortAggregate test case
-        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
-        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
-        if (useObjectHashAgg) {
-          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
-        } else {
-          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
-        }
-        objHashAggOrSortAggDF.collect()
-      }
-    }
-  }
-
-  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
-    " before using it") {
-    Seq(
-      monotonically_increasing_id(), spark_partition_id(),
-      rand(Random.nextLong()), randn(Random.nextLong())
-    ).foreach(assertNoExceptions)
-  }
-
   test("SPARK-21281 use string types by default if array and map have no argument") {
     val ds = spark.range(1)
     var expectedSchema = new StructType()
