Skip to content

Commit 941c300

Browse files
authored
chore: generate Float/Double NaN (apache#2695)
1 parent c91036e commit 941c300

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ object FuzzDataGenerator {
4444
val defaultBaseDate: Long =
4545
new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("3333-05-25 12:34:56").getTime
4646

47+
val floatNaNLiteral = "FLOAT('NaN')"
48+
val doubleNaNLiteral = "DOUBLE('NaN')"
49+
4750
def generateSchema(options: SchemaGenOptions): StructType = {
4851
val primitiveTypes = options.primitiveTypes
4952
val dataTypes = ListBuffer[DataType]()
@@ -168,6 +171,7 @@ object FuzzDataGenerator {
168171
case 4 => Float.MaxValue
169172
case 5 => 0.0f
170173
case 6 if options.generateNegativeZero => -0.0f
174+
case 7 if options.generateNaN => Float.NaN
171175
case _ => r.nextFloat()
172176
}
173177
})
@@ -181,6 +185,7 @@ object FuzzDataGenerator {
181185
case 4 => Double.MaxValue
182186
case 5 => 0.0
183187
case 6 if options.generateNegativeZero => -0.0
188+
case 7 if options.generateNaN => Double.NaN
184189
case _ => r.nextDouble()
185190
}
186191
})
@@ -257,6 +262,7 @@ case class SchemaGenOptions(
257262
case class DataGenOptions(
258263
allowNull: Boolean = true,
259264
generateNegativeZero: Boolean = true,
265+
generateNaN: Boolean = true,
260266
baseDate: Long = FuzzDataGenerator.defaultBaseDate,
261267
customStrings: Seq[String] = Seq.empty,
262268
maxStringLength: Int = 8)

spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.types._
2929

3030
import org.apache.comet.DataTypeSupport.isComplexType
3131
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
32+
import org.apache.comet.testing.FuzzDataGenerator.{doubleNaNLiteral, floatNaNLiteral}
3233

3334
class CometFuzzTestSuite extends CometFuzzTestBase {
3435

@@ -71,8 +72,20 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
7172
// Construct the string for the default value based on the column type.
7273
val defaultValueString = defaultValueType match {
7374
// These explicit type definitions for TINYINT, SMALLINT, FLOAT, DOUBLE, and DATE are only needed for 3.4.
74-
case "TINYINT" | "SMALLINT" | "FLOAT" | "DOUBLE" =>
75+
case "TINYINT" | "SMALLINT" =>
7576
s"$defaultValueType(${defaultValueRow.get(0)})"
77+
case "FLOAT" =>
78+
if (Float.NaN.equals(defaultValueRow.get(0))) {
79+
floatNaNLiteral
80+
} else {
81+
s"$defaultValueType(${defaultValueRow.get(0)})"
82+
}
83+
case "DOUBLE" =>
84+
if (Double.NaN.equals(defaultValueRow.get(0))) {
85+
doubleNaNLiteral
86+
} else {
87+
s"$defaultValueType(${defaultValueRow.get(0)})"
88+
}
7689
case "DATE" => s"$defaultValueType('${defaultValueRow.get(0)}')"
7790
case "STRING" => s"'${defaultValueRow.get(0)}'"
7891
case "TIMESTAMP" | "TIMESTAMP_NTZ" => s"TIMESTAMP '${defaultValueRow.get(0)}'"
@@ -101,7 +114,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase {
101114
.asInstanceOf[Array[Byte]]
102115
.sameElements(spark.sql(sql).collect()(0).get(0).asInstanceOf[Array[Byte]]))
103116
} else {
104-
assert(defaultValueRow.get(0) == spark.sql(sql).collect()(0).get(0))
117+
assert(defaultValueRow.get(0).equals(spark.sql(sql).collect()(0).get(0)))
105118
}
106119
}
107120
}

0 commit comments

Comments
 (0)