1919
2020package org .apache .comet .fuzz
2121
22- import java .io .{BufferedWriter , File , FileWriter , PrintWriter , StringWriter }
22+ import java .io .{BufferedWriter , FileWriter , PrintWriter , StringWriter }
2323
2424import scala .collection .mutable
2525import scala .io .Source
@@ -37,7 +37,7 @@ object QueryRunner {
3737 new BufferedWriter (new FileWriter (outputFilename))
3838 }
3939
40- def assertCorrectness (
40+ def executeSQLAndAssertCorrectness (
4141 spark : SparkSession ,
4242 sql : String ,
4343 showFailedSparkQueries : Boolean = false ,
@@ -59,35 +59,17 @@ object QueryRunner {
5959 val cometRows = df.collect()
6060 val cometPlan = df.queryExecution.executedPlan.toString
6161
62- if (sparkRows.length == cometRows.length) {
63- var i = 0
64- while (i < sparkRows.length) {
65- val l = sparkRows(i)
66- val r = cometRows(i)
67- assert(l.length == r.length)
68- for (j <- 0 until l.length) {
69- if (! same(l(j), r(j))) {
70- showSQL(output, sql)
71- showPlans(output, sparkPlan, cometPlan)
72- output.write(s " First difference at row $i: \n " )
73- output.write(" Spark: `" + formatRow(l) + " `\n " )
74- output.write(" Comet: `" + formatRow(r) + " `\n " )
75- i = sparkRows.length
76- }
77- }
78- i += 1
79- }
80- } else {
81- showSQL(output, sql)
82- showPlans(output, sparkPlan, cometPlan)
83- output.write(
84- s " [ERROR] Spark produced ${sparkRows.length} rows and " +
85- s " Comet produced ${cometRows.length} rows. \n " )
86- }
62+ QueryComparison .assertSameRows(
63+ sparkRows,
64+ cometRows,
65+ sqlText = sql,
66+ sparkPlan,
67+ cometPlan,
68+ output)
8769 } catch {
8870 case e : Exception =>
8971 // the query worked in Spark but failed in Comet, so this is likely a bug in Comet
90- showSQL(output, sql)
72+ QueryComparison . showSQL(output, sql)
9173 output.write(s " [ERROR] Query failed in Comet: ${e.getMessage}: \n " )
9274 output.write(" ```\n " )
9375 val sw = new StringWriter ()
@@ -105,7 +87,7 @@ object QueryRunner {
10587 case e : Exception =>
10688 // we expect many generated queries to be invalid
10789 if (showFailedSparkQueries) {
108- showSQL(output, sql)
90+ QueryComparison . showSQL(output, sql)
10991 output.write(s " Query failed in Spark: ${e.getMessage}\n " )
11092 }
11193 }
@@ -133,67 +115,18 @@ object QueryRunner {
133115 try {
134116 querySource
135117 .getLines()
136- .foreach(sql => assertCorrectness(spark, sql, showFailedSparkQueries, output = w))
118+ .foreach(sql =>
119+ executeSQLAndAssertCorrectness(spark, sql, showFailedSparkQueries, output = w))
137120
138121 } finally {
139122 w.close()
140123 querySource.close()
141124 }
142125 }
143126
144- def runTPCQueries (
145- spark : SparkSession ,
146- dataFolderName : String ,
147- queriesFolderName : String ): Unit = {
148- val output = QueryRunner .createOutputMdFile()
149-
150- // Load data tables from dataFolder
151- val dataFolder = new File (dataFolderName)
152- if (! dataFolder.exists() || ! dataFolder.isDirectory) {
153- // scalastyle:off println
154- println(s " Error: Data folder $dataFolder does not exist or is not a directory " )
155- // scalastyle:on println
156- sys.exit(- 1 )
157- }
158-
159- // Traverse data folder and create temp views
160- dataFolder.listFiles().filter(_.isDirectory).foreach { tableDir =>
161- val tableName = tableDir.getName
162- val parquetPath = s " ${tableDir.getAbsolutePath}/*.parquet "
163- spark.read.parquet(parquetPath).createOrReplaceTempView(tableName)
164- // scalastyle:off println
165- println(s " Created temp view: $tableName from $parquetPath" )
166- // scalastyle:on println
167- }
168-
169- // Load and run queries from queriesFolder
170- val queriesFolder = new File (queriesFolderName)
171- if (! queriesFolder.exists() || ! queriesFolder.isDirectory) {
172- // scalastyle:off println
173- println(s " Error: Queries folder $queriesFolder does not exist or is not a directory " )
174- // scalastyle:on println
175- sys.exit(- 1 )
176- }
177-
178- // Traverse queries folder and run each .sql file
179- queriesFolder.listFiles().filter(f => f.isFile && f.getName.endsWith(" .sql" )).foreach {
180- sqlFile =>
181- // scalastyle:off println
182- println(s " Running query from: ${sqlFile.getName}" )
183- // scalastyle:on println
184-
185- val querySource = Source .fromFile(sqlFile)
186- try {
187- val sql = querySource.mkString
188- QueryRunner .assertCorrectness(spark, sql, showFailedSparkQueries = false , output)
189- } finally {
190- querySource.close()
191- }
192- }
193-
194- output.close()
195- }
127+ }
196128
129+ object QueryComparison {
197130 private def same (l : Any , r : Any ): Boolean = {
198131 if (l == null || r == null ) {
199132 return l == null && r == null
@@ -235,7 +168,7 @@ object QueryRunner {
235168 row.toSeq.map(format).mkString(" ," )
236169 }
237170
238- private def showSQL (w : BufferedWriter , sql : String , maxLength : Int = 120 ): Unit = {
171+ def showSQL (w : BufferedWriter , sql : String , maxLength : Int = 120 ): Unit = {
239172 w.write(" ## SQL\n " )
240173 w.write(" ```\n " )
241174 val words = sql.split(" " )
@@ -262,4 +195,37 @@ object QueryRunner {
262195 w.write(s " ``` \n $cometPlan\n ``` \n " )
263196 }
264197
/**
 * Compares the rows collected from Spark with the rows collected from Comet and writes a
 * report to `output` when they differ. Nothing is written when the results match.
 *
 * If the row counts differ, the SQL, both plans and an `[ERROR]` line with the two counts are
 * written. Otherwise rows are compared position by position with `same`, and only the first
 * differing row is reported (SQL, both plans, then the Spark and Comet renderings of that
 * row).
 *
 * @param sparkRows rows produced by the Spark run
 * @param cometRows rows produced by the Comet run
 * @param sqlText the SQL text that produced both results, echoed in the report
 * @param sparkPlan string form of the Spark plan, echoed in the report
 * @param cometPlan string form of the Comet plan, echoed in the report
 * @param output writer receiving the report; it is not closed here
 */
def assertSameRows(
    sparkRows: Array[Row],
    cometRows: Array[Row],
    sqlText: String,
    sparkPlan: String,
    cometPlan: String,
    output: BufferedWriter): Unit = {
  if (sparkRows.length != cometRows.length) {
    showSQL(output, sqlText)
    showPlans(output, sparkPlan, cometPlan)
    output.write(
      s"[ERROR] Spark produced ${sparkRows.length} rows and " +
        s"Comet produced ${cometRows.length} rows.\n")
  } else {
    // Stop at the first row containing any differing column. The previous loop kept scanning
    // the remaining columns of that row after reporting, so a row with several differing
    // columns was reported several times, with a wrong row index after the first report.
    val firstMismatch = sparkRows.indices.find { i =>
      val l = sparkRows(i)
      val r = cometRows(i)
      assert(l.length == r.length)
      (0 until l.length).exists(j => !same(l(j), r(j)))
    }

    firstMismatch.foreach { i =>
      showSQL(output, sqlText)
      showPlans(output, sparkPlan, cometPlan)
      output.write(s"First difference at row $i:\n")
      output.write("Spark: `" + formatRow(sparkRows(i)) + "`\n")
      output.write("Comet: `" + formatRow(cometRows(i)) + "`\n")
    }
  }
}
265231}
0 commit comments