Skip to content

Commit 81e19be

Browse files
authored
chore: Add unit tests for CometExecRule (#2863)
1 parent 0fec0f5 commit 81e19be

File tree

5 files changed

+265
-1
lines changed

5 files changed

+265
-1
lines changed

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ jobs:
141141
org.apache.spark.CometPluginsDefaultSuite
142142
org.apache.spark.CometPluginsNonOverrideSuite
143143
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
144+
org.apache.comet.rules.CometExecRuleSuite
144145
org.apache.spark.sql.CometTPCDSQuerySuite
145146
org.apache.spark.sql.CometTPCDSQueryTestSuite
146147
org.apache.spark.sql.CometTPCHQuerySuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ jobs:
106106
org.apache.spark.CometPluginsDefaultSuite
107107
org.apache.spark.CometPluginsNonOverrideSuite
108108
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
109+
org.apache.comet.rules.CometExecRuleSuite
109110
org.apache.spark.sql.CometTPCDSQuerySuite
110111
org.apache.spark.sql.CometTPCDSQueryTestSuite
111112
org.apache.spark.sql.CometTPCHQuerySuite

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,22 @@ object CometConf extends ShimCometConf {
654654
.booleanConf
655655
.createWithDefault(COMET_SCHEMA_EVOLUTION_ENABLED_DEFAULT)
656656

657+
/**
 * Test-only switch consulted by `CometHashAggregateExec.getSupportLevel`. When set to false,
 * hash aggregates that contain `Partial` or `PartialMerge` mode aggregate expressions are
 * reported as unsupported. Defaults to true.
 */
val COMET_ENABLE_PARTIAL_HASH_AGGREGATE: ConfigEntry[Boolean] =
  conf("spark.comet.testing.aggregate.partialMode.enabled")
    .internal()
    .category(CATEGORY_TESTING)
    .doc(
      "Whether Comet is allowed to convert hash aggregates that contain Partial or " +
        "PartialMerge mode aggregate expressions. This setting is used in unit tests to " +
        "verify that CometExecRule does not produce mixed Spark/Comet aggregates")
    .booleanConf
    .createWithDefault(true)
664+
665+
/**
 * Test-only switch consulted by `CometHashAggregateExec.getSupportLevel`. When set to false,
 * hash aggregates that contain `Final` mode aggregate expressions are reported as unsupported.
 * Defaults to true.
 */
val COMET_ENABLE_FINAL_HASH_AGGREGATE: ConfigEntry[Boolean] =
  conf("spark.comet.testing.aggregate.finalMode.enabled")
    .internal()
    .category(CATEGORY_TESTING)
    .doc(
      "Whether Comet is allowed to convert hash aggregates that contain Final mode " +
        "aggregate expressions. This setting is used in unit tests to verify that " +
        "CometExecRule does not produce mixed Spark/Comet aggregates")
    .booleanConf
    .createWithDefault(true)
672+
657673
val COMET_SPARK_TO_ARROW_ENABLED: ConfigEntry[Boolean] =
658674
conf("spark.comet.sparkToColumnar.enabled")
659675
.category(CATEGORY_TESTING)

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.sql.catalyst.InternalRow
3333
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
34-
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial}
34+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
3535
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
3636
import org.apache.spark.sql.catalyst.plans._
3737
import org.apache.spark.sql.catalyst.plans.physical._
@@ -1234,6 +1234,20 @@ object CometHashAggregateExec
12341234
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
12351235
CometConf.COMET_EXEC_AGGREGATE_ENABLED)
12361236

1237+
/**
 * Decides whether the given hash aggregate may be handled by Comet.
 *
 * Some unit tests need to disable partial or final hash aggregate support to test that
 * CometExecRule does not allow mixed Spark/Comet aggregates, so two test-only configs are
 * consulted here. The partial-mode check takes precedence over the final-mode check.
 *
 * @param op
 *   the Spark hash aggregate being considered
 * @return
 *   `Unsupported` with a reason when the matching test config is disabled and the operator
 *   contains an aggregate expression of that mode, otherwise `Compatible()`
 */
override def getSupportLevel(op: HashAggregateExec): SupportLevel = {
  // The helpers are defs so that the config is read before the expressions are scanned and
  // the scan is skipped entirely when the config leaves that mode enabled.
  def partialDisabled: Boolean = !CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf)
  def finalDisabled: Boolean = !CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf)

  def hasPartialMode: Boolean = op.aggregateExpressions.exists { aggExpr =>
    aggExpr.mode match {
      case Partial | PartialMerge => true
      case _ => false
    }
  }
  def hasFinalMode: Boolean = op.aggregateExpressions.exists(aggExpr => aggExpr.mode == Final)

  if (partialDisabled && hasPartialMode) {
    Unsupported(Some("Partial aggregates disabled via test config"))
  } else if (finalDisabled && hasFinalMode) {
    Unsupported(Some("Final aggregates disabled via test config"))
  } else {
    Compatible()
  }
}
1250+
12371251
override def convert(
12381252
aggregate: HashAggregateExec,
12391253
builder: Operator.Builder,
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.rules
21+
22+
import scala.util.Random
23+
24+
import org.apache.spark.sql._
25+
import org.apache.spark.sql.comet._
26+
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
27+
import org.apache.spark.sql.execution._
28+
import org.apache.spark.sql.execution.adaptive.QueryStageExec
29+
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
30+
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
31+
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
32+
33+
import org.apache.comet.CometConf
34+
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
35+
36+
/**
37+
* Test suite specifically for CometExecRule transformation logic. Tests the rule's ability to
38+
* transform Spark operators to Comet operators, fallback mechanisms, configuration handling, and
39+
* edge cases.
40+
*/
41+
class CometExecRuleSuite extends CometTestBase {
42+
43+
/**
 * Applies CometExecRule to a plan and returns the transformed plan. The input is passed through
 * `stripAQEPlan` before the rule sees it.
 */
private def applyCometExecRule(plan: SparkPlan): SparkPlan = {
  val rule = CometExecRule(spark)
  val strippedPlan = stripAQEPlan(plan)
  rule.apply(strippedPlan)
}
47+
48+
/**
 * Builds the data frame used by every test in this suite: a nullable integer `id` column and a
 * nullable string `name` column, generated by FuzzDataGenerator from a Random seeded with 42
 * (the generator is asked for 100 rows with default options).
 */
private def createTestDataFrame = {
  val idField = StructField("id", DataTypes.IntegerType, nullable = true)
  val nameField = StructField("name", DataTypes.StringType, nullable = true)
  val schema = new StructType(Array(idField, nameField))
  val seededRandom = new Random(42)
  FuzzDataGenerator.generateDataFrame(seededRandom, spark, schema, 100, DataGenOptions())
}
56+
57+
/**
 * Builds the executed plan for the given SQL with `CometConf.COMET_ENABLED` set to "false".
 *
 * The plan is captured in a local variable assigned inside the `withSQLConf` block, so that
 * `executedPlan` is read while the override is in effect.
 *
 * @param spark
 *   session used to run the SQL
 * @param sql
 *   query text to plan
 * @return
 *   the executed plan obtained with Comet disabled
 */
private def createSparkPlan(spark: SparkSession, sql: String): SparkPlan = {
  var planWithoutComet: SparkPlan = null
  withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
    planWithoutComet = spark.sql(sql).queryExecution.executedPlan
  }
  planWithoutComet
}
66+
67+
/**
 * Counts the operators in the plan that are instances of the specified class.
 *
 * The plan is passed through `stripAQEPlan` first. Query stages are handled by counting inside
 * the plan they wrap via a recursive call.
 *
 * @param plan
 *   the plan to inspect
 * @param opClass
 *   the operator class to look for; an operator matches when it is an instance of this class,
 *   i.e. the class itself or a subtype of it
 * @return
 *   the number of matching operators
 */
private def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
  stripAQEPlan(plan).collect {
    case stage: QueryStageExec =>
      countOperators(stage.plan, opClass)
    // Ask whether the operator is an instance of the requested class. The previous check,
    // op.getClass.isAssignableFrom(opClass), had the receiver and argument swapped, so it
    // only matched on exact class equality and could never match a requested trait.
    case op if opClass.isInstance(op) => 1
  }.sum
}
75+
76+
test(
  "CometExecRule should apply basic operator transformations, but only when Comet is enabled") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val query = "SELECT id, id * 2 as doubled FROM test_data WHERE id % 2 == 0"
    val vanillaPlan = createSparkPlan(spark, query)

    // The plan built with Comet disabled contains exactly one projection and one filter
    assert(countOperators(vanillaPlan, classOf[ProjectExec]) == 1)
    assert(countOperators(vanillaPlan, classOf[FilterExec]) == 1)

    Seq(true, false).foreach { cometEnabled =>
      withSQLConf(
        CometConf.COMET_ENABLED.key -> cometEnabled.toString,
        CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {

        val rewrittenPlan = applyCometExecRule(vanillaPlan)

        // Enabled: both Spark operators are replaced by their Comet counterparts.
        // Disabled: the operator counts are unchanged.
        val (expectedSparkOps, expectedCometOps) = if (cometEnabled) (0, 1) else (1, 0)

        assert(countOperators(rewrittenPlan, classOf[ProjectExec]) == expectedSparkOps)
        assert(countOperators(rewrittenPlan, classOf[FilterExec]) == expectedSparkOps)
        assert(countOperators(rewrittenPlan, classOf[CometProjectExec]) == expectedCometOps)
        assert(countOperators(rewrittenPlan, classOf[CometFilterExec]) == expectedCometOps)
      }
    }
  }
}
110+
111+
test("CometExecRule should apply hash aggregate transformations") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val query = "SELECT COUNT(*), SUM(id) FROM test_data GROUP BY (id % 3)"
    val vanillaPlan = createSparkPlan(spark, query)

    // The Spark plan is expected to contain two hash aggregates
    val sparkAggCount = countOperators(vanillaPlan, classOf[HashAggregateExec])
    assert(sparkAggCount == 2)

    withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val rewrittenPlan = applyCometExecRule(vanillaPlan)

      // Every Spark hash aggregate must have been replaced by a Comet one
      val remainingSparkAggs = countOperators(rewrittenPlan, classOf[HashAggregateExec])
      assert(remainingSparkAggs == 0)
      val cometAggs = countOperators(rewrittenPlan, classOf[CometHashAggregateExec])
      assert(cometAggs == sparkAggCount)
    }
  }
}
133+
134+
// TODO this test exposes the bug described in
// https://github.com/apache/datafusion-comet/issues/1389
ignore("CometExecRule should not allow Comet partial and Spark final hash aggregate") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val query = "SELECT COUNT(*), SUM(id) FROM test_data GROUP BY (id % 3)"
    val vanillaPlan = createSparkPlan(spark, query)

    // The Spark plan is expected to contain two hash aggregates
    val sparkAggCount = countOperators(vanillaPlan, classOf[HashAggregateExec])
    assert(sparkAggCount == 2)

    withSQLConf(
      CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
      CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val rewrittenPlan = applyCometExecRule(vanillaPlan)

      // if the final aggregate cannot be converted to Comet, then neither should be
      val remainingSparkAggs = countOperators(rewrittenPlan, classOf[HashAggregateExec])
      assert(remainingSparkAggs == sparkAggCount)
      val cometAggs = countOperators(rewrittenPlan, classOf[CometHashAggregateExec])
      assert(cometAggs == 0)
    }
  }
}
159+
160+
test("CometExecRule should not allow Spark partial and Comet final hash aggregate") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val query = "SELECT COUNT(*), SUM(id) FROM test_data GROUP BY (id % 3)"
    val vanillaPlan = createSparkPlan(spark, query)

    // The Spark plan is expected to contain two hash aggregates
    val sparkAggCount = countOperators(vanillaPlan, classOf[HashAggregateExec])
    assert(sparkAggCount == 2)

    withSQLConf(
      CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
      CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val rewrittenPlan = applyCometExecRule(vanillaPlan)

      // if the partial aggregate cannot be converted to Comet, then neither should be
      val remainingSparkAggs = countOperators(rewrittenPlan, classOf[HashAggregateExec])
      assert(remainingSparkAggs == sparkAggCount)
      val cometAggs = countOperators(rewrittenPlan, classOf[CometHashAggregateExec])
      assert(cometAggs == 0)
    }
  }
}
183+
184+
test("CometExecRule should apply broadcast exchange transformations") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    // Self-join with a broadcast hint on the right-hand side
    val joinQuery =
      "SELECT /*+ BROADCAST(b) */ a.id, b.name FROM test_data a JOIN test_data b ON a.id = b.id"
    val vanillaPlan = createSparkPlan(spark, joinQuery)

    // The Spark plan is expected to contain a single broadcast exchange
    val sparkBroadcastCount = countOperators(vanillaPlan, classOf[BroadcastExchangeExec])
    assert(sparkBroadcastCount == 1)

    withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val rewrittenPlan = applyCometExecRule(vanillaPlan)

      // The Spark broadcast exchange must have been replaced by the Comet one
      val remainingSparkBroadcasts =
        countOperators(rewrittenPlan, classOf[BroadcastExchangeExec])
      assert(remainingSparkBroadcasts == 0)
      val cometBroadcasts = countOperators(rewrittenPlan, classOf[CometBroadcastExchangeExec])
      assert(cometBroadcasts == sparkBroadcastCount)
    }
  }
}
208+
209+
test("CometExecRule should apply shuffle exchange transformations") {
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val query = "SELECT id, COUNT(*) FROM test_data GROUP BY id ORDER BY id"
    val vanillaPlan = createSparkPlan(spark, query)

    // The Spark plan is expected to contain two shuffle exchanges
    val sparkShuffleCount = countOperators(vanillaPlan, classOf[ShuffleExchangeExec])
    assert(sparkShuffleCount == 2)

    withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val rewrittenPlan = applyCometExecRule(vanillaPlan)

      // Every Spark shuffle exchange must have been replaced by a Comet one
      val remainingSparkShuffles = countOperators(rewrittenPlan, classOf[ShuffleExchangeExec])
      assert(remainingSparkShuffles == 0)
      val cometShuffles = countOperators(rewrittenPlan, classOf[CometShuffleExchangeExec])
      assert(cometShuffles == sparkShuffleCount)
    }
  }
}
231+
232+
}

0 commit comments

Comments
 (0)