Commit 2d22b51

Pt2
Signed-off-by: Andy HF Kwok <[email protected]>
1 parent a1b74cc commit 2d22b51

13 files changed, +56 −51 lines changed
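
Note: across the commit the change is mechanical. Implicit numeric widenings (Int to Long, Byte/Short to Int) are replaced with explicit .toLong/.toInt conversions. The commit message does not state the motivation, but this is the usual way to silence scalac's -Wnumeric-widen (-Ywarn-numeric-widen on Scala 2.12) lint warning.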

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -1162,7 +1162,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = {
     val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded
     spark
-      .range(num)
+      .range(num.toLong)
       .map(_ % div)
       // Parquet doesn't allow column names with spaces, have to add an alias here.
       // Minus 500 here so that negative decimals are also tested.
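
Note: SparkSession.range is declared over Long (for example def range(end: Long): Dataset[java.lang.Long]), so calling it with an Int compiles only through implicit numeric widening, which lints such as -Wnumeric-widen flag. A minimal sketch of the before/after, assuming a live SparkSession named spark:

  val num: Int = 1000
  spark.range(num)        // Int widened to Long implicitly; triggers the warning
  spark.range(num.toLong) // identical behavior, conversion made explicit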

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -1801,7 +1801,7 @@ class CometExecSuite extends CometTestBase {
     withTable("t1") {
       val numRows = 10
       spark
-        .range(numRows)
+        .range(numRows.toLong)
         .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
         .repartition(3) // Force repartition to test data will come to single partition
         .write
@@ -1838,7 +1838,7 @@ class CometExecSuite extends CometTestBase {
     withTable("t1") {
       val numRows = 10
       spark
-        .range(numRows)
+        .range(numRows.toLong)
         .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
         .repartition(3) // Force repartition to test data will come to single partition
         .write
@@ -1869,7 +1869,7 @@ class CometExecSuite extends CometTestBase {
     withTable("t1") {
       val numRows = 10
       spark
-        .range(numRows)
+        .range(numRows.toLong)
         .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
         .repartition(3) // Force repartition to test data will come to single partition
         .write

spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala

Lines changed: 12 additions & 12 deletions
@@ -416,15 +416,15 @@ abstract class ParquetReadSuite extends CometTestBase {
         opt match {
           case Some(i) =>
             record.add(0, i % 2 == 0)
-            record.add(1, i.toByte)
-            record.add(2, i.toShort)
+            record.add(1, i.toByte.toInt)
+            record.add(2, i.toShort.toInt)
             record.add(3, i)
             record.add(4, i.toLong)
             record.add(5, i.toFloat)
             record.add(6, i.toDouble)
             record.add(7, i.toString * 48)
-            record.add(8, (-i).toByte)
-            record.add(9, (-i).toShort)
+            record.add(8, (-i).toByte.toInt)
+            record.add(9, (-i).toShort.toInt)
             record.add(10, -i)
             record.add(11, (-i).toLong)
             record.add(12, i.toString)
@@ -639,8 +639,8 @@ abstract class ParquetReadSuite extends CometTestBase {
         opt match {
           case Some(i) =>
             record.add(0, i % 2 == 0)
-            record.add(1, i.toByte)
-            record.add(2, i.toShort)
+            record.add(1, i.toByte.toInt)
+            record.add(2, i.toShort.toInt)
             record.add(3, i)
             record.add(4, i.toLong)
             record.add(5, i.toFloat)
@@ -1575,15 +1575,15 @@ abstract class ParquetReadSuite extends CometTestBase {
         opt match {
           case Some(i) =>
             record.add(0, i % 2 == 0)
-            record.add(1, i.toByte)
-            record.add(2, i.toShort)
+            record.add(1, i.toByte.toInt)
+            record.add(2, i.toShort.toInt)
             record.add(3, i)
             record.add(4, i.toLong)
             record.add(5, i.toFloat)
             record.add(6, i.toDouble)
             record.add(7, i.toString * 48)
-            record.add(8, (-i).toByte)
-            record.add(9, (-i).toShort)
+            record.add(8, (-i).toByte.toInt)
+            record.add(9, (-i).toShort.toInt)
             record.add(10, -i)
             record.add(11, (-i).toLong)
             record.add(12, i.toString)
@@ -1672,7 +1672,7 @@ abstract class ParquetReadSuite extends CometTestBase {
         val record = new SimpleGroup(schema)
         opt match {
           case Some(i) =>
-            record.add(0, i.toShort)
+            record.add(0, i.toShort.toInt)
             record.add(1, i)
             record.add(2, i.toLong)
           case _ =>
@@ -1765,7 +1765,7 @@ abstract class ParquetReadSuite extends CometTestBase {
   }

   private def withId(id: Int) =
-    new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id).build()
+    new MetadataBuilder().putLong(ParquetUtils.FIELD_ID_METADATA_KEY, id.toLong).build()

   // Based on Spark ParquetIOSuite.test("vectorized reader: array of nested struct")
   test("array of nested struct with and without field id") {

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 14 additions & 14 deletions
@@ -557,15 +557,15 @@ abstract class CometTestBase
         opt match {
           case Some(i) =>
             record.add(0, i % 2 == 0)
-            record.add(1, i.toByte)
-            record.add(2, i.toShort)
+            record.add(1, i.toByte.toInt)
+            record.add(2, i.toShort.toInt)
             record.add(3, i)
             record.add(4, i.toLong)
             record.add(5, i.toFloat)
             record.add(6, i.toDouble)
             record.add(7, i.toString * 48)
-            record.add(8, (-i).toByte)
-            record.add(9, (-i).toShort)
+            record.add(8, (-i).toByte.toInt)
+            record.add(9, (-i).toShort.toInt)
             record.add(10, -i)
             record.add(11, (-i).toLong)
             record.add(12, i.toString)
@@ -586,15 +586,15 @@ abstract class CometTestBase
         val i = rand.nextLong()
         val record = new SimpleGroup(schema)
         record.add(0, i % 2 == 0)
-        record.add(1, i.toByte)
-        record.add(2, i.toShort)
+        record.add(1, i.toByte.toInt)
+        record.add(2, i.toShort.toInt)
         record.add(3, i.toInt)
         record.add(4, i)
         record.add(5, java.lang.Float.intBitsToFloat(i.toInt))
         record.add(6, java.lang.Double.longBitsToDouble(i))
         record.add(7, i.toString * 24)
-        record.add(8, (-i).toByte)
-        record.add(9, (-i).toShort)
+        record.add(8, (-i).toByte.toInt)
+        record.add(9, (-i).toShort.toInt)
         record.add(10, (-i).toInt)
         record.add(11, -i)
         record.add(12, i.toString)
@@ -643,7 +643,7 @@ abstract class CometTestBase
         if (rand.nextBoolean()) {
           None
         } else {
-          Some(getValue(i, div))
+          Some(getValue(i.toLong, div.toLong))
         }
       }
       expected.foreach { opt =>
@@ -697,7 +697,7 @@ abstract class CometTestBase
         if (rand.nextBoolean()) {
           None
         } else {
-          Some(getValue(i, div))
+          Some(getValue(i.toLong, div.toLong))
        }
      }
      expected.foreach { opt =>
@@ -875,7 +875,7 @@ abstract class CometTestBase
     val div = if (dictionaryEnabled) 10 else n // maps value to a small range for dict to kick in

     val expected = (0 until n).map { i =>
-      Some(getValue(i, div))
+      Some(getValue(i.toLong, div.toLong))
     }
     expected.foreach { opt =>
       val timestampFormats = List(
@@ -923,7 +923,7 @@ abstract class CometTestBase
   def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = {
     val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded
     spark
-      .range(num)
+      .range(num.toLong)
       .map(_ % div)
       // Parquet doesn't allow column names with spaces, have to add an alias here.
       // Minus 500 here so that negative decimals are also tested.
@@ -1103,8 +1103,8 @@ abstract class CometTestBase
         val record = new SimpleGroup(schema)
         opt match {
           case Some(i) =>
-            record.add(0, i.toByte)
-            record.add(1, i.toShort)
+            record.add(0, i.toByte.toInt)
+            record.add(1, i.toShort.toInt)
             record.add(2, i)
             record.add(3, i.toLong)
             record.add(4, rand.nextFloat())
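
Note: the getValue calls now widen both arguments explicitly, which suggests the helper is declared over Long. Its definition is not part of this diff, so the shape below is inferred from the call sites, not verified:

  // Hypothetical stand-in for the suite's helper:
  def getValue(v: Long, div: Long): Long = v % div
  val i: Int = 7
  Some(getValue(i.toLong, 10L)) // explicit Int -> Long instead of implicit widening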

spark/src/test/scala/org/apache/spark/sql/GenTPCHData.scala

Lines changed: 4 additions & 2 deletions
@@ -65,7 +65,9 @@ object GenTPCHData {
       // Install the data generators in all nodes
       // TODO: think a better way to install on each worker node
       // such as https://stackoverflow.com/a/40876671
-      spark.range(0, workers, 1, workers).foreach(worker => installDBGEN(baseDir)(worker))
+      spark
+        .range(0L, workers.toLong, 1L, workers)
+        .foreach(worker => installDBGEN(baseDir)(worker))
       s"${baseDir}/dbgen"
     } else {
       config.dbgenDir
@@ -91,7 +93,7 @@ object GenTPCHData {

     // Clean up
     if (defaultDbgenDir != null) {
-      spark.range(0, workers, 1, workers).foreach { _ =>
+      spark.range(0L, workers.toLong, 1L, workers).foreach { _ =>
         val _ = FileUtils.deleteQuietly(defaultDbgenDir)
       }
     }
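
Note: the four-argument overload used here is range(start: Long, end: Long, step: Long, numPartitions: Int), which is why the first three arguments become Long literals or .toLong conversions while workers is still passed as a plain Int for numPartitions:

  spark.range(0L, workers.toLong, 1L, workers) // start/end/step: Long; numPartitions: Int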

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 4 additions & 4 deletions
@@ -66,7 +66,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
         s"single aggregate ${aggregateFunction.toString}",
-      values,
+      values.toLong,
       output = output)

     withTempPath { dir =>
@@ -104,7 +104,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
         s"single aggregate ${aggregateFunction.toString} on decimal",
-      values,
+      values.toLong,
       output = output)

     val df = makeDecimalDataFrame(values, dataType, false);
@@ -145,7 +145,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     new Benchmark(
       s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
         s"single aggregate ${aggregateFunction.toString}",
-      values,
+      values.toLong,
       output = output)

     withTempPath { dir =>
@@ -186,7 +186,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
         s"multiple aggregates ${aggregateFunction.toString}",
-      values,
+      values.toLong,
       output = output)

     withTempPath { dir =>
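
Note: the second constructor parameter of org.apache.spark.benchmark.Benchmark (valuesPerIteration) is a Long, so passing the Int values relied on implicit widening; values.toLong makes it explicit. Sketch, with a placeholder benchmark name:

  val benchmark = new Benchmark("placeholder name", values.toLong, output = output)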

spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
     val dataType = IntegerType
     val benchmark = new Benchmark(
       s"Binary op ${dataType.sql}, dictionary = $useDictionary",
-      values,
+      values.toLong,
       output = output)

     withTempPath { dir =>
@@ -78,7 +78,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
       useDictionary: Boolean): Unit = {
     val benchmark = new Benchmark(
       s"Binary op ${dataType.sql}, dictionary = $useDictionary",
-      values,
+      values.toLong,
       output = output)
     val df = makeDecimalDataFrame(values, dataType, useDictionary)


spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
     withTempTable(tbl) {
       import spark.implicits._
       spark
-        .range(values)
+        .range(values.toLong)
         .map(_ => if (useDictionary) Random.nextLong % 5 else Random.nextLong)
         .createOrReplaceTempView(tbl)
       runBenchmark(benchmarkName)(f(values))
@@ -168,7 +168,7 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {

     val div = if (useDictionary) 5 else values
     spark
-      .range(values)
+      .range(values.toLong)
       .map(_ % div)
       .select((($"value" - 500) / 100.0) cast decimal as Symbol("dec"))
   }

spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ import org.apache.comet.CometConf
 object CometConditionalExpressionBenchmark extends CometBenchmarkBase {

   def caseWhenExprBenchmark(values: Int): Unit = {
-    val benchmark = new Benchmark("Case When Expr", values, output = output)
+    val benchmark = new Benchmark("Case When Expr", values.toLong, output = output)

     withTempPath { dir =>
       withTempTable("parquetV1Table") {
@@ -65,7 +65,7 @@ object CometConditionalExpressionBenchmark extends CometBenchmarkBase {
   }

   def ifExprBenchmark(values: Int): Unit = {
-    val benchmark = new Benchmark("If Expr", values, output = output)
+    val benchmark = new Benchmark("If Expr", values.toLong, output = output)

     withTempPath { dir =>
       withTempTable("parquetV1Table") {

spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
           s"select cast(timestamp_micros(cast(value/100000 as integer)) as date) as dt FROM $tbl"))
         Seq("YEAR", "YYYY", "YY", "MON", "MONTH", "MM").foreach { level =>
           val isDictionary = if (useDictionary) "(Dictionary)" else ""
-          runWithComet(s"Date Truncate $isDictionary - $level", values) {
+          runWithComet(s"Date Truncate $isDictionary - $level", values.toLong) {
             spark.sql(s"select trunc(dt, '$level') from parquetV1Table").noop()
           }
         }
@@ -68,7 +68,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
           "WEEK",
           "QUARTER").foreach { level =>
           val isDictionary = if (useDictionary) "(Dictionary)" else ""
-          runWithComet(s"Timestamp Truncate $isDictionary - $level", values) {
+          runWithComet(s"Timestamp Truncate $isDictionary - $level", values.toLong) {
             spark.sql(s"select date_trunc('$level', ts) from parquetV1Table").noop()
           }
         }
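
Note: runWithComet is a helper from CometBenchmarkBase rather than a Spark API; passing values.toLong here implies its cardinality parameter is a Long. Assumed shape, inferred from the call sites only:

  // def runWithComet(name: String, cardinality: Long)(f: => Unit): Unit = ...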
