Skip to content

Commit 7fd5797

Browse files
authored
minor: Improve testing of math scalar functions (#1896)
1 parent 8f94c25 commit 7fd5797

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,38 +1241,54 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12411241
}
12421242
}
12431243

1244+
private val doubleValues: Seq[Double] = Seq(
1245+
-1.0,
1246+
// TODO we should eventually enable negative zero but there are known issues still
1247+
// -0.0,
1248+
0.0,
1249+
+1.0,
1250+
Double.MinValue,
1251+
Double.MaxValue,
1252+
Double.NaN,
1253+
Double.MinPositiveValue,
1254+
Double.PositiveInfinity,
1255+
Double.NegativeInfinity)
1256+
12441257
test("various math scalar functions") {
1245-
Seq("true", "false").foreach { dictionary =>
1246-
withSQLConf("parquet.enable.dictionary" -> dictionary) {
1247-
withParquetTable(
1248-
(-5 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
1249-
"tbl",
1250-
withDictionary = dictionary.toBoolean) {
1251-
checkSparkAnswerWithTol(
1252-
"SELECT abs(_1), acos(_2), asin(_1), atan(_2), atan2(_1, _2), cos(_1) FROM tbl")
1253-
checkSparkAnswerWithTol(
1254-
"SELECT exp(_1), ln(_2), log10(_1), log2(_1), pow(_1, _2) FROM tbl")
1255-
// TODO: comment in the round tests once supported
1256-
// checkSparkAnswerWithTol("SELECT round(_1), round(_2) FROM tbl")
1257-
checkSparkAnswerWithTol("SELECT signum(_1), sin(_1), sqrt(_1) FROM tbl")
1258-
checkSparkAnswerWithTol("SELECT tan(_1) FROM tbl")
1258+
val data = doubleValues.map(n => (n, n))
1259+
withParquetTable(data, "tbl") {
1260+
// expressions with single arg
1261+
for (expr <- Seq(
1262+
"acos",
1263+
"asin",
1264+
"atan",
1265+
"cos",
1266+
"exp",
1267+
"ln",
1268+
"log10",
1269+
"log2",
1270+
"sin",
1271+
"sqrt",
1272+
"tan")) {
1273+
val df = checkSparkAnswerWithTol(s"SELECT $expr(_1), $expr(_2) FROM tbl")
1274+
val cometProjectExecs = collect(df.queryExecution.executedPlan) {
1275+
case op: CometProjectExec => op
1276+
}
1277+
assert(cometProjectExecs.length == 1, expr)
1278+
}
1279+
// expressions with two args
1280+
for (expr <- Seq("atan2", "pow")) {
1281+
val df = checkSparkAnswerWithTol(s"SELECT $expr(_1, _2) FROM tbl")
1282+
val cometProjectExecs = collect(df.queryExecution.executedPlan) {
1283+
case op: CometProjectExec => op
12591284
}
1285+
assert(cometProjectExecs.length == 1, expr)
12601286
}
12611287
}
12621288
}
12631289

12641290
test("expm1") {
1265-
val testValues = Seq(
1266-
-1,
1267-
0,
1268-
+1,
1269-
Double.MinValue,
1270-
Double.MaxValue,
1271-
Double.NaN,
1272-
Double.MinPositiveValue,
1273-
Double.PositiveInfinity,
1274-
Double.NegativeInfinity)
1275-
val testValuesRepeated = testValues.flatMap(v => Seq.fill(1000)(v))
1291+
val testValuesRepeated = doubleValues.flatMap(v => Seq.fill(1000)(v))
12761292
withParquetTable(testValuesRepeated.map(n => (n, n)), "tbl") {
12771293
checkSparkAnswerWithTol("SELECT expm1(_1) FROM tbl")
12781294
}

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,21 +219,22 @@ abstract class CometTestBase
219219
/**
220220
* Check the answer of a Comet SQL query with Spark result using absolute tolerance.
221221
*/
222-
protected def checkSparkAnswerWithTol(query: String, absTol: Double = 1e-6): Unit = {
222+
protected def checkSparkAnswerWithTol(query: String, absTol: Double = 1e-6): DataFrame = {
223223
checkSparkAnswerWithTol(sql(query), absTol)
224224
}
225225

226226
/**
227227
* Check the answer of a Comet DataFrame with Spark result using absolute tolerance.
228228
*/
229-
protected def checkSparkAnswerWithTol(df: => DataFrame, absTol: Double): Unit = {
229+
protected def checkSparkAnswerWithTol(df: => DataFrame, absTol: Double): DataFrame = {
230230
var expected: Array[Row] = Array.empty
231231
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
232232
val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
233233
expected = dfSpark.collect()
234234
}
235235
val dfComet = Dataset.ofRows(spark, df.logicalPlan)
236236
checkAnswerWithTol(dfComet, expected, absTol: Double)
237+
dfComet
237238
}
238239

239240
protected def checkSparkMaybeThrows(

0 commit comments

Comments (0)