
package org.apache.comet.fuzz

-import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}
+import java.io.{BufferedWriter, File, FileWriter, PrintWriter, StringWriter}

-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
import scala.io.Source

import org.apache.spark.sql.{Row, SparkSession}

object QueryRunner {

+  def createOutputMdFile(): BufferedWriter = {
+    val outputFilename = s"results-${System.currentTimeMillis()}.md"
+    // scalastyle:off println
+    println(s"Writing results to $outputFilename")
+    // scalastyle:on println
+
+    new BufferedWriter(new FileWriter(outputFilename))
+  }
+
+  def assertCorrectness(
+      spark: SparkSession,
+      sql: String,
+      showFailedSparkQueries: Boolean = false,
+      output: BufferedWriter): Unit = {
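+    // run the same SQL twice, once with Comet disabled and once enabled, then compare the
+    // collected results row by row and write any differences to the markdown report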
+    try {
+      // execute with Spark
+      spark.conf.set("spark.comet.enabled", "false")
+      val df = spark.sql(sql)
+      val sparkRows = df.collect()
+      val sparkPlan = df.queryExecution.executedPlan.toString
+
+      // execute with Comet
+      try {
+        spark.conf.set("spark.comet.enabled", "true")
+        // enable conversion from Spark rows to columnar format for complex type support
+        // until Comet supports complex types natively
+        spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
+        spark.conf.set("spark.comet.convert.parquet.enabled", "true")
+        val df = spark.sql(sql)
+        val cometRows = df.collect()
+        val cometPlan = df.queryExecution.executedPlan.toString
+
+        if (sparkRows.length == cometRows.length) {
+          var i = 0
+          while (i < sparkRows.length) {
+            val l = sparkRows(i)
+            val r = cometRows(i)
+            assert(l.length == r.length)
+            for (j <- 0 until l.length) {
+              if (!same(l(j), r(j))) {
+                showSQL(output, sql)
+                showPlans(output, sparkPlan, cometPlan)
+                output.write(s"First difference at row $i:\n")
+                output.write("Spark: `" + formatRow(l) + "`\n")
+                output.write("Comet: `" + formatRow(r) + "`\n")
+                i = sparkRows.length
+              }
+            }
+            i += 1
+          }
+        } else {
+          showSQL(output, sql)
+          showPlans(output, sparkPlan, cometPlan)
+          output.write(
+            s"[ERROR] Spark produced ${sparkRows.length} rows and " +
+              s"Comet produced ${cometRows.length} rows.\n")
+        }
+      } catch {
+        case e: Exception =>
+          // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
+          showSQL(output, sql)
+          output.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
+          output.write("```\n")
+          val sw = new StringWriter()
+          val p = new PrintWriter(sw)
+          e.printStackTrace(p)
+          p.close()
+          output.write(s"${sw.toString}\n")
+          output.write("```\n")
+      }
+
+      // flush after every query so that results are saved in the event of the driver crashing
+      output.flush()
+
+    } catch {
+      case e: Exception =>
+        // we expect many generated queries to be invalid
+        if (showFailedSparkQueries) {
+          showSQL(output, sql)
+          output.write(s"Query failed in Spark: ${e.getMessage}\n")
+        }
+    }
+  }
+
  def runQueries(
      spark: SparkSession,
      numFiles: Int,
      filename: String,
      showFailedSparkQueries: Boolean = false): Unit = {

-    val outputFilename = s"results-${System.currentTimeMillis()}.md"
-    // scalastyle:off println
-    println(s"Writing results to $outputFilename")
-    // scalastyle:on println
-
-    val w = new BufferedWriter(new FileWriter(outputFilename))
+    val w = createOutputMdFile()

    // register input files
    for (i <- 0 until numFiles) {
@@ -55,104 +133,100 @@ object QueryRunner {
    try {
      querySource
        .getLines()
-        .foreach(sql => {
-
-          try {
-            // execute with Spark
-            spark.conf.set("spark.comet.enabled", "false")
-            val df = spark.sql(sql)
-            val sparkRows = df.collect()
-            val sparkPlan = df.queryExecution.executedPlan.toString
-
-            // execute with Comet
-            try {
-              spark.conf.set("spark.comet.enabled", "true")
-              // complex type support until we support it natively
-              spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
-              spark.conf.set("spark.comet.convert.parquet.enabled", "true")
-              val df = spark.sql(sql)
-              val cometRows = df.collect()
-              val cometPlan = df.queryExecution.executedPlan.toString
-
-              if (sparkRows.length == cometRows.length) {
-                var i = 0
-                while (i < sparkRows.length) {
-                  val l = sparkRows(i)
-                  val r = cometRows(i)
-                  assert(l.length == r.length)
-                  for (j <- 0 until l.length) {
-                    if (!same(l(j), r(j))) {
-                      showSQL(w, sql)
-                      showPlans(w, sparkPlan, cometPlan)
-                      w.write(s"First difference at row $i:\n")
-                      w.write("Spark: `" + formatRow(l) + "`\n")
-                      w.write("Comet: `" + formatRow(r) + "`\n")
-                      i = sparkRows.length
-                    }
-                  }
-                  i += 1
-                }
-              } else {
-                showSQL(w, sql)
-                showPlans(w, sparkPlan, cometPlan)
-                w.write(
-                  s"[ERROR] Spark produced ${sparkRows.length} rows and " +
-                    s"Comet produced ${cometRows.length} rows.\n")
-              }
-            } catch {
-              case e: Exception =>
-                // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
-                showSQL(w, sql)
-                w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
-                w.write("```\n")
-                val sw = new StringWriter()
-                val p = new PrintWriter(sw)
-                e.printStackTrace(p)
-                p.close()
-                w.write(s"${sw.toString}\n")
-                w.write("```\n")
-            }
-
-            // flush after every query so that results are saved in the event of the driver crashing
-            w.flush()
-
-          } catch {
-            case e: Exception =>
-              // we expect many generated queries to be invalid
-              if (showFailedSparkQueries) {
-                showSQL(w, sql)
-                w.write(s"Query failed in Spark: ${e.getMessage}\n")
-              }
-          }
-        })
+        .foreach(sql => assertCorrectness(spark, sql, showFailedSparkQueries, output = w))

    } finally {
      w.close()
      querySource.close()
    }
  }

+  def runTPCQueries(
+      spark: SparkSession,
+      dataFolderName: String,
+      queriesFolderName: String): Unit = {
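+    // dataFolderName is expected to contain one subdirectory of Parquet files per table, and
+    // queriesFolderName one .sql file per query (e.g. pre-generated TPC-H or TPC-DS data/queries)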
+    val output = QueryRunner.createOutputMdFile()
+
+    // Load data tables from dataFolder
+    val dataFolder = new File(dataFolderName)
+    if (!dataFolder.exists() || !dataFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Data folder $dataFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse data folder and create temp views
+    dataFolder.listFiles().filter(_.isDirectory).foreach { tableDir =>
+      val tableName = tableDir.getName
+      val parquetPath = s"${tableDir.getAbsolutePath}/*.parquet"
+      spark.read.parquet(parquetPath).createOrReplaceTempView(tableName)
+      // scalastyle:off println
+      println(s"Created temp view: $tableName from $parquetPath")
+      // scalastyle:on println
+    }
+
+    // Load and run queries from queriesFolder
+    val queriesFolder = new File(queriesFolderName)
+    if (!queriesFolder.exists() || !queriesFolder.isDirectory) {
+      // scalastyle:off println
+      println(s"Error: Queries folder $queriesFolder does not exist or is not a directory")
+      // scalastyle:on println
+      sys.exit(-1)
+    }
+
+    // Traverse queries folder and run each .sql file
+    queriesFolder.listFiles().filter(f => f.isFile && f.getName.endsWith(".sql")).foreach {
+      sqlFile =>
+        // scalastyle:off println
+        println(s"Running query from: ${sqlFile.getName}")
+        // scalastyle:on println
+
+        val querySource = Source.fromFile(sqlFile)
+        try {
+          val sql = querySource.mkString
+          QueryRunner.assertCorrectness(spark, sql, showFailedSparkQueries = false, output)
+        } finally {
+          querySource.close()
+        }
+    }
+
+    output.close()
+  }
+
  private def same(l: Any, r: Any): Boolean = {
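+    // two nulls compare as equal; a null compared with a non-null value is a difference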
+    if (l == null || r == null) {
+      return l == null && r == null
+    }
    (l, r) match {
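+      // infinities must match in sign, NaN only matches NaN, and finite values are compared
+      // with a small absolute tolerance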
+      case (a: Float, b: Float) if a.isPosInfinity => b.isPosInfinity
+      case (a: Float, b: Float) if a.isNegInfinity => b.isNegInfinity
      case (a: Float, b: Float) if a.isInfinity => b.isInfinity
      case (a: Float, b: Float) if a.isNaN => b.isNaN
      case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+      case (a: Double, b: Double) if a.isPosInfinity => b.isPosInfinity
+      case (a: Double, b: Double) if a.isNegInfinity => b.isNegInfinity
      case (a: Double, b: Double) if a.isInfinity => b.isInfinity
      case (a: Double, b: Double) if a.isNaN => b.isNaN
      case (a: Double, b: Double) => (a - b).abs <= 0.000001
      case (a: Array[_], b: Array[_]) =>
        a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
-      case (a: WrappedArray[_], b: WrappedArray[_]) =>
+      case (a: mutable.WrappedArray[_], b: mutable.WrappedArray[_]) =>
        a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
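+      // compare struct values (Rows) field by field using the same rules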
+      case (a: Row, b: Row) =>
+        val aa = a.toSeq
+        val bb = b.toSeq
+        aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
      case (a, b) => a == b
    }
  }

  private def format(value: Any): String = {
    value match {
      case null => "NULL"
-      case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
+      case v: mutable.WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
      case v: Array[Byte] => s"[${v.mkString(",")}]"
+      case r: Row => formatRow(r)
      case other => other.toString
    }
  }