Skip to content

Commit 383b56d

Browse files
authored
chore: Add tolerance for ComparisonTool (#2699)
* chore: Add tolerance for `ComparisonTool` to identify error threshold for floating point comparisons
1 parent c6136aa commit 383b56d

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

fuzz-testing/src/main/scala/org/apache/comet/fuzz/ComparisonTool.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class ComparisonToolConf(arguments: Seq[String]) extends ScallopConf(arguments)
3131
opt[String](required = true, descr = "Folder with Spark produced results in Parquet format")
3232
val inputCometFolder: ScallopOption[String] =
3333
opt[String](required = true, descr = "Folder with Comet produced results in Parquet format")
34+
val tolerance: ScallopOption[Double] =
35+
opt[Double](default = Some(0.000002), descr = "Tolerance for floating point comparisons")
3436
}
3537
addSubcommand(compareParquet)
3638
verify()
@@ -49,7 +51,8 @@ object ComparisonTool {
4951
compareParquetFolders(
5052
spark,
5153
conf.compareParquet.inputSparkFolder(),
52-
conf.compareParquet.inputCometFolder())
54+
conf.compareParquet.inputCometFolder(),
55+
conf.compareParquet.tolerance())
5356

5457
case _ =>
5558
// scalastyle:off println
@@ -62,7 +65,8 @@ object ComparisonTool {
6265
private def compareParquetFolders(
6366
spark: SparkSession,
6467
sparkFolderPath: String,
65-
cometFolderPath: String): Unit = {
68+
cometFolderPath: String,
69+
tolerance: Double): Unit = {
6670

6771
val output = QueryRunner.createOutputMdFile()
6872

@@ -115,7 +119,7 @@ object ComparisonTool {
115119
val cometRows = cometDf.orderBy(cometDf.columns.map(functions.col): _*).collect()
116120

117121
// Compare the results
118-
if (QueryComparison.assertSameRows(sparkRows, cometRows, output)) {
122+
if (QueryComparison.assertSameRows(sparkRows, cometRows, output, tolerance)) {
119123
output.write(s"Subfolder $subfolderName: ${sparkRows.length} rows matched\n\n")
120124
} else {
121125
// Output schema if dataframes are not equal

fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ object QueryComparison {
148148
def assertSameRows(
149149
sparkRows: Array[Row],
150150
cometRows: Array[Row],
151-
output: BufferedWriter): Boolean = {
151+
output: BufferedWriter,
152+
tolerance: Double = 0.000001): Boolean = {
152153
if (sparkRows.length == cometRows.length) {
153154
var i = 0
154155
while (i < sparkRows.length) {
@@ -164,7 +165,7 @@ object QueryComparison {
164165

165166
assert(l.length == r.length)
166167
for (j <- 0 until l.length) {
167-
if (!same(l(j), r(j))) {
168+
if (!same(l(j), r(j), tolerance)) {
168169
output.write(s"First difference at row $i:\n")
169170
output.write("Spark: `" + formatRow(l) + "`\n")
170171
output.write("Comet: `" + formatRow(r) + "`\n")
@@ -186,7 +187,7 @@ object QueryComparison {
186187
true
187188
}
188189

189-
private def same(l: Any, r: Any): Boolean = {
190+
private def same(l: Any, r: Any, tolerance: Double): Boolean = {
190191
if (l == null || r == null) {
191192
return l == null && r == null
192193
}
@@ -195,20 +196,20 @@ object QueryComparison {
195196
case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
196197
case (a: Float, b: Float) if a.isInfinity => b.isInfinity
197198
case (a: Float, b: Float) if a.isNaN => b.isNaN
198-
case (a: Float, b: Float) => (a - b).abs <= 0.000001f
199+
case (a: Float, b: Float) => (a - b).abs <= tolerance
199200
case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
200201
case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
201202
case (a: Double, b: Double) if a.isInfinity => b.isInfinity
202203
case (a: Double, b: Double) if a.isNaN => b.isNaN
203-
case (a: Double, b: Double) => (a - b).abs <= 0.000001
204+
case (a: Double, b: Double) => (a - b).abs <= tolerance
204205
case (a: Array[_], b: Array[_]) =>
205-
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
206+
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2, tolerance))
206207
case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>
207-
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
208+
a.length == b.length && a.zip(b).forall(x => same(x._1, x._2, tolerance))
208209
case (a: Row, b: Row) =>
209210
val aa = a.toSeq
210211
val bb = b.toSeq
211-
aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
212+
aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2, tolerance))
212213
case (a, b) => a == b
213214
}
214215
}

0 commit comments

Comments
 (0)