
Commit 98a9ca2

chore: Enable Comet explicitly in CometTPCDSQueryTestSuite (#1559)
Enable Comet support explicitly rather than relying on the default setting.
1 parent 5224108 commit 98a9ca2
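
For reference, a minimal sketch of what "enabling Comet explicitly" means here, using the same CometConf entries this commit pins in the test suites; the wrapper object and its name are illustrative only:

import org.apache.spark.SparkConf
import org.apache.comet.CometConf

// Illustrative sketch: pin the four Comet settings this commit adds,
// instead of relying on whatever the build's defaults happen to be.
object ExplicitCometConf {
  def apply(): SparkConf =
    new SparkConf()
      .set(CometConf.COMET_ENABLED.key, "true")
      .set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
      .set(CometConf.COMET_EXEC_ENABLED.key, "true")
      .set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
}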

File tree

3 files changed (+31, -15 lines)


spark/src/test/scala/org/apache/spark/sql/CometTPCDSQueryTestSuite.scala

Lines changed: 13 additions & 5 deletions

@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, strin
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.TestSparkSession
 
+import org.apache.comet.CometConf
+
 /**
  * Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
  * copy Spark `TPCDSQueryTestSuite`.
@@ -164,12 +166,18 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQue
     }
   }
 
+  val baseConf: Map[String, String] = Map(
+    CometConf.COMET_ENABLED.key -> "true",
+    CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true",
+    CometConf.COMET_EXEC_ENABLED.key -> "true",
+    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true")
+
   val sortMergeJoinConf: Map[String, String] = Map(
     SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
     SQLConf.PREFER_SORTMERGEJOIN.key -> "true")
 
-  val broadcastHashJoinConf: Map[String, String] = Map(
-    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")
+  val broadcastHashJoinConf: Map[String, String] =
+    Map(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")
 
   val shuffledHashJoinConf: Map[String, String] = Map(
     SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
@@ -213,7 +221,7 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQue
         // that can cause OOM in GitHub Actions
         if (!(sortMergeJoin && name == "q72")) {
           System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
-          runQuery(queryString, goldenFile, conf)
+          runQuery(queryString, goldenFile, baseConf ++ conf)
         }
       }
     }
@@ -231,7 +239,7 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQue
         // that can cause OOM in GitHub Actions
         if (!(sortMergeJoin && name == "q72")) {
          System.gc() // SPARK-37368
-         runQuery(queryString, goldenFile, conf)
+         runQuery(queryString, goldenFile, baseConf ++ conf)
        }
      }
    }
@@ -245,7 +253,7 @@ class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQue
       val goldenFile = new File(s"$baseResourcePath/extended", s"$name.sql.out")
       joinConfs.foreach { conf =>
         System.gc() // SPARK-37368
-        runQuery(queryString, goldenFile, conf)
+        runQuery(queryString, goldenFile, baseConf ++ conf)
       }
     }
   }

spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala

Lines changed: 0 additions & 2 deletions

@@ -90,10 +90,8 @@ class CometTPCHQuerySuite extends QueryTest with TPCBase with ShimCometTPCHQuery
     conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
     conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
-    conf.set(CometConf.COMET_SHUFFLE_MODE.key, "jvm")
     conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
     conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
-    conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
   }
 
   protected override def createSparkSession: TestSparkSession = {

spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala

Lines changed: 18 additions & 8 deletions

@@ -30,7 +30,7 @@ import org.scalatest.Tag
 import org.scalatestplus.junit.JUnitRunner
 
 import org.apache.spark.{DebugFilesystem, SparkConf}
-import org.apache.spark.sql.{QueryTest, SparkSession, SQLContext}
+import org.apache.spark.sql.{CometTestBase, SparkSession, SQLContext}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -41,7 +41,7 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions, IntegrationTest
  */
 @RunWith(classOf[JUnitRunner])
 @IntegrationTestSuite
-class ParquetEncryptionITCase extends QueryTest with SQLTestUtils {
+class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils {
   private val encoder = Base64.getEncoder
   private val footerKey =
     encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
@@ -81,7 +81,12 @@ class ParquetEncryptionITCase extends QueryTest with SQLTestUtils {
           val parquetDF = spark.read.parquet(parquetDir)
           assert(parquetDF.inputFiles.nonEmpty)
           val readDataset = parquetDF.select("a", "b", "c")
-          checkAnswer(readDataset, inputDF)
+
+          if (CometConf.COMET_ENABLED.get(conf)) {
+            checkSparkAnswerAndOperator(readDataset)
+          } else {
+            checkAnswer(readDataset, inputDF)
+          }
         }
       }
     }
@@ -118,19 +123,24 @@ class ParquetEncryptionITCase extends QueryTest with SQLTestUtils {
           val parquetDF = spark.read.parquet(parquetDir)
           assert(parquetDF.inputFiles.nonEmpty)
           val readDataset = parquetDF.select("a", "b", "c")
-          checkAnswer(readDataset, inputDF)
+
+          if (CometConf.COMET_ENABLED.get(conf)) {
+            checkSparkAnswerAndOperator(readDataset)
+          } else {
+            checkAnswer(readDataset, inputDF)
+          }
         }
       }
     }
   }
 
-  protected def sparkConf: SparkConf = {
+  protected override def sparkConf: SparkConf = {
     val conf = new SparkConf()
     conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
     conf
   }
 
-  protected def createSparkSession: SparkSession = {
+  protected override def createSparkSession: SparkSession = {
     SparkSession
       .builder()
       .config(sparkConf)
@@ -159,8 +169,8 @@ class ParquetEncryptionITCase extends QueryTest with SQLTestUtils {
   }
 
   private var _spark: SparkSession = _
-  protected implicit def spark: SparkSession = _spark
-  protected implicit def sqlContext: SQLContext = _spark.sqlContext
+  protected implicit override def spark: SparkSession = _spark
+  protected implicit override def sqlContext: SQLContext = _spark.sqlContext
 
   /**
    * Verify that the directory contains an encrypted parquet in encrypted footer mode by means of
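
The new branch in both encryption tests picks an assertion based on the session conf rather than assuming one mode. A commented restatement of the pattern; the helper semantics are as understood from the Comet test utilities and worth verifying against CometTestBase itself:

// When Comet is enabled, use the stricter helper, which (per CometTestBase,
// as understood) compares against Spark's result and also checks that the
// physical plan actually uses Comet operators; otherwise fall back to a
// plain row-level comparison against the original input.
if (CometConf.COMET_ENABLED.get(conf)) {
  checkSparkAnswerAndOperator(readDataset)
} else {
  checkAnswer(readDataset, inputDF)
}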
