@@ -148,7 +148,8 @@ object QueryComparison {
148148 def assertSameRows (
149149 sparkRows : Array [Row ],
150150 cometRows : Array [Row ],
151- output : BufferedWriter ): Boolean = {
151+ output : BufferedWriter ,
152+ tolerance : Double = 0.000001 ): Boolean = {
152153 if (sparkRows.length == cometRows.length) {
153154 var i = 0
154155 while (i < sparkRows.length) {
@@ -164,7 +165,7 @@ object QueryComparison {
164165
165166 assert(l.length == r.length)
166167 for (j <- 0 until l.length) {
167- if (! same(l(j), r(j))) {
168+ if (! same(l(j), r(j), tolerance )) {
168169 output.write(s " First difference at row $i: \n " )
169170 output.write(" Spark: `" + formatRow(l) + " `\n " )
170171 output.write(" Comet: `" + formatRow(r) + " `\n " )
@@ -186,7 +187,7 @@ object QueryComparison {
186187 true
187188 }
188189
189- private def same (l : Any , r : Any ): Boolean = {
190+ private def same (l : Any , r : Any , tolerance : Double ): Boolean = {
190191 if (l == null || r == null ) {
191192 return l == null && r == null
192193 }
@@ -195,20 +196,20 @@ object QueryComparison {
195196 case (a : Float , b : Float ) if a.isNegInfinity => b.isNegInfinity
196197 case (a : Float , b : Float ) if a.isInfinity => b.isInfinity
197198 case (a : Float , b : Float ) if a.isNaN => b.isNaN
198- case (a : Float , b : Float ) => (a - b).abs <= 0.000001f
199+ case (a : Float , b : Float ) => (a - b).abs <= tolerance
199200 case (a : Double , b : Double ) if a.isPosInfinity => b.isPosInfinity
200201 case (a : Double , b : Double ) if a.isNegInfinity => b.isNegInfinity
201202 case (a : Double , b : Double ) if a.isInfinity => b.isInfinity
202203 case (a : Double , b : Double ) if a.isNaN => b.isNaN
203- case (a : Double , b : Double ) => (a - b).abs <= 0.000001
204+ case (a : Double , b : Double ) => (a - b).abs <= tolerance
204205 case (a : Array [_], b : Array [_]) =>
205- a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
206+ a.length == b.length && a.zip(b).forall(x => same(x._1, x._2, tolerance ))
206207 case (a : mutable.WrappedArray [_], b : mutable.WrappedArray [_]) =>
207- a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
208+ a.length == b.length && a.zip(b).forall(x => same(x._1, x._2, tolerance ))
208209 case (a : Row , b : Row ) =>
209210 val aa = a.toSeq
210211 val bb = b.toSeq
211- aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2))
212+ aa.length == bb.length && aa.zip(bb).forall(x => same(x._1, x._2, tolerance ))
212213 case (a, b) => a == b
213214 }
214215 }
0 commit comments