
Commit 91f9ff1

chore: add TPC queries to be run by fuzzer correctness checker
1 parent eeb1566 commit 91f9ff1

2 files changed: +170 −80 lines changed


fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala

Lines changed: 16 additions & 0 deletions
@@ -61,6 +61,17 @@ class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
       opt[Int](required = false, descr = "Number of input files to use")
   }
   addSubcommand(runQueries)
+  object runTPCQueries extends Subcommand("runTPC") {
+    val dataFolder: ScallopOption[String] =
+      opt[String](
+        required = true,
+        descr = "Folder for input data. Expected folder struct `$dataFolder/tableName/*.parquet`")
+    val queriesFolder: ScallopOption[String] =
+      opt[String](
+        required = true,
+        descr = "Folder for test queries. Expected folder struct `$queriesFolder/*.sql`")
+  }
+  addSubcommand(runTPCQueries)
   verify()
 }

@@ -104,6 +115,11 @@ object Main {
           conf.generateQueries.numQueries())
       case Some(conf.runQueries) =>
         QueryRunner.runQueries(spark, conf.runQueries.numFiles(), conf.runQueries.filename())
+      case Some(conf.runTPCQueries) =>
+        QueryRunner.runTPCQueries(
+          spark,
+          conf.runTPCQueries.dataFolder(),
+          conf.runTPCQueries.queriesFolder())
       case _ =>
         // scalastyle:off println
         println("Invalid subcommand")

fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala

Lines changed: 154 additions & 80 deletions
@@ -19,27 +19,105 @@
 
 package org.apache.comet.fuzz
 
-import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}
+import java.io.{BufferedWriter, File, FileWriter, PrintWriter, StringWriter}
 
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
 import scala.io.Source
 
 import org.apache.spark.sql.{Row, SparkSession}
 
 object QueryRunner {
 
+  def createOutputMdFile(): BufferedWriter = {
+    val outputFilename = s"results-${System.currentTimeMillis()}.md"
+    // scalastyle:off println
+    println(s"Writing results to $outputFilename")
+    // scalastyle:on println
+
+    new BufferedWriter(new FileWriter(outputFilename))
+  }
+
+  def assertCorrectness(
+      spark: SparkSession,
+      sql: String,
+      showFailedSparkQueries: Boolean = false,
+      output: BufferedWriter): Unit = {
+    try {
+      // execute with Spark
+      spark.conf.set("spark.comet.enabled", "false")
+      val df = spark.sql(sql)
+      val sparkRows = df.collect()
+      val sparkPlan = df.queryExecution.executedPlan.toString
+
+      // execute with Comet
+      try {
+        spark.conf.set("spark.comet.enabled", "true")
+        // complex type support until we support it natively
+        spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
+        spark.conf.set("spark.comet.convert.parquet.enabled", "true")
+        val df = spark.sql(sql)
+        val cometRows = df.collect()
+        val cometPlan = df.queryExecution.executedPlan.toString
+
+        if (sparkRows.length == cometRows.length) {
+          var i = 0
+          while (i < sparkRows.length) {
+            val l = sparkRows(i)
+            val r = cometRows(i)
+            assert(l.length == r.length)
+            for (j <- 0 until l.length) {
+              if (!same(l(j), r(j))) {
+                showSQL(output, sql)
+                showPlans(output, sparkPlan, cometPlan)
+                output.write(s"First difference at row $i:\n")
+                output.write("Spark: `" + formatRow(l) + "`\n")
+                output.write("Comet: `" + formatRow(r) + "`\n")
+                i = sparkRows.length
+              }
+            }
+            i += 1
+          }
+        } else {
+          showSQL(output, sql)
+          showPlans(output, sparkPlan, cometPlan)
+          output.write(
+            s"[ERROR] Spark produced ${sparkRows.length} rows and " +
+              s"Comet produced ${cometRows.length} rows.\n")
+        }
+      } catch {
+        case e: Exception =>
+          // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
+          showSQL(output, sql)
+          output.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
+          output.write("```\n")
+          val sw = new StringWriter()
+          val p = new PrintWriter(sw)
+          e.printStackTrace(p)
+          p.close()
+          output.write(s"${sw.toString}\n")
+          output.write("```\n")
+      }
+
+      // flush after every query so that results are saved in the event of the driver crashing
+      output.flush()
+
+    } catch {
+      case e: Exception =>
+        // we expect many generated queries to be invalid
+        if (showFailedSparkQueries) {
+          showSQL(output, sql)
+          output.write(s"Query failed in Spark: ${e.getMessage}\n")
+        }
+    }
+  }
+
   def runQueries(
       spark: SparkSession,
       numFiles: Int,
       filename: String,
       showFailedSparkQueries: Boolean = false): Unit = {
 
-    val outputFilename = s"results-${System.currentTimeMillis()}.md"
-    // scalastyle:off println
-    println(s"Writing results to $outputFilename")
-    // scalastyle:on println
-
-    val w = new BufferedWriter(new FileWriter(outputFilename))
+    val w = createOutputMdFile()
 
     // register input files
     for (i <- 0 until numFiles) {
@@ -55,104 +133,100 @@ object QueryRunner {
     try {
       querySource
         .getLines()
-        .foreach(sql => {
-
-          try {
-            // execute with Spark
-            spark.conf.set("spark.comet.enabled", "false")
-            val df = spark.sql(sql)
-            val sparkRows = df.collect()
-            val sparkPlan = df.queryExecution.executedPlan.toString
-
-            // execute with Comet
-            try {
-              spark.conf.set("spark.comet.enabled", "true")
-              // complex type support until we support it natively
-              spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
-              spark.conf.set("spark.comet.convert.parquet.enabled", "true")
-              val df = spark.sql(sql)
-              val cometRows = df.collect()
-              val cometPlan = df.queryExecution.executedPlan.toString
-
-              if (sparkRows.length == cometRows.length) {
-                var i = 0
-                while (i < sparkRows.length) {
-                  val l = sparkRows(i)
-                  val r = cometRows(i)
-                  assert(l.length == r.length)
-                  for (j <- 0 until l.length) {
-                    if (!same(l(j), r(j))) {
-                      showSQL(w, sql)
-                      showPlans(w, sparkPlan, cometPlan)
-                      w.write(s"First difference at row $i:\n")
-                      w.write("Spark: `" + formatRow(l) + "`\n")
-                      w.write("Comet: `" + formatRow(r) + "`\n")
-                      i = sparkRows.length
-                    }
-                  }
-                  i += 1
-                }
-              } else {
-                showSQL(w, sql)
-                showPlans(w, sparkPlan, cometPlan)
-                w.write(
-                  s"[ERROR] Spark produced ${sparkRows.length} rows and " +
-                    s"Comet produced ${cometRows.length} rows.\n")
-              }
-            } catch {
-              case e: Exception =>
-                // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
-                showSQL(w, sql)
-                w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
-                w.write("```\n")
-                val sw = new StringWriter()
-                val p = new PrintWriter(sw)
-                e.printStackTrace(p)
-                p.close()
-                w.write(s"${sw.toString}\n")
-                w.write("```\n")
-            }
-
-            // flush after every query so that results are saved in the event of the driver crashing
-            w.flush()
-
-          } catch {
-            case e: Exception =>
-              // we expect many generated queries to be invalid
-              if (showFailedSparkQueries) {
-                showSQL(w, sql)
-                w.write(s"Query failed in Spark: ${e.getMessage}\n")
-              }
-          }
-        })
+        .foreach(sql => assertCorrectness(spark, sql, showFailedSparkQueries, output = w))
 
     } finally {
       w.close()
       querySource.close()
     }
   }
 
+  def runTPCQueries(
+      spark: SparkSession,
+      dataFolderName: String,
+      queriesFolderName: String): Unit = {
+    val output = QueryRunner.createOutputMdFile()
+
+    // Load data tables from dataFolder
+    val dataFolder = new File(dataFolderName)
+    if (!dataFolder.exists() || !dataFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Data folder $dataFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse data folder and create temp views
+    dataFolder.listFiles().filter(_.isDirectory).foreach { tableDir =>
+      val tableName = tableDir.getName
+      val parquetPath = s"${tableDir.getAbsolutePath}/*.parquet"
+      spark.read.parquet(parquetPath).createOrReplaceTempView(tableName)
+      // scalastyle:off println
+      println(s"Created temp view: $tableName from $parquetPath")
+      // scalastyle:on println
+    }
+
+    // Load and run queries from queriesFolder
+    val queriesFolder = new File(queriesFolderName)
+    if (!queriesFolder.exists() || !queriesFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Queries folder $queriesFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse queries folder and run each .sql file
+    queriesFolder.listFiles().filter(f => f.isFile && f.getName.endsWith(".sql")).foreach {
+      sqlFile =>
+        // scalastyle:off println
+        println(s"Running query from: ${sqlFile.getName}")
+        // scalastyle:on println
+
+        val querySource = Source.fromFile(sqlFile)
+        try {
+          val sql = querySource.mkString
+          QueryRunner.assertCorrectness(spark, sql, showFailedSparkQueries = false, output)
+        } finally {
+          querySource.close()
+        }
+    }
+
+    output.close()
+  }
+
   private def same(l: Any, r: Any): Boolean = {
+    if (l == null || r == null) {
+      return l == null && r == null
+    }
     (l, r) match {
+      case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity
+      case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
       case (a: Float, b: Float) if a.isInfinity => b.isInfinity
       case (a: Float, b: Float) if a.isNaN => b.isNaN
       case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+      case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
+      case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
       case (a: Double, b: Double) if a.isInfinity => b.isInfinity
      case (a: Double, b: Double) if a.isNaN => b.isNaN
       case (a: Double, b: Double) => (a - b).abs <= 0.000001
       case (a: Array[_], b: Array[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
-      case (a: WrappedArray[_], b: WrappedArray[_]) =>
+      case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>
         a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
+      case (a: Row, b: Row) =>
+        val aa = a.toSeq
+        val bb = b.toSeq
+        aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
       case (a, b) => a == b
     }
   }
 
   private def format(value: Any): String = {
     value match {
       case null => "NULL"
-      case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
+      case v: mutable.WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
       case v: Array[Byte] => s"[${v.mkString(",")}]"
+      case r: Row => formatRow(r)
       case other => other.toString
     }
   }
