
Commit 8810d19

tests: FuzzDataGenerator instead of Parquet-specific generator (apache#2616)
1 parent 8a7c67a commit 8810d19

File tree

2 files changed (+259 -208 lines)
Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.testing

import java.math.{BigDecimal, RoundingMode}
import java.nio.charset.Charset
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.{Instant, LocalDateTime, ZoneId}

import scala.collection.mutable.ListBuffer
import scala.util.Random

import org.apache.commons.lang3.RandomStringUtils
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._

object FuzzDataGenerator {

  /**
   * Date to use as a base for generating temporal columns. Random integers will be added to or
   * subtracted from this value.
   *
   * The date was chosen to trigger generating timestamps that are larger than a 64-bit
   * nanosecond timestamp can represent, so that we test support for INT96 timestamps.
   */
  val defaultBaseDate: Long =
    // note: the pattern must be lowercase yyyy-MM-dd with a 24-hour HH; the uppercase
    // YYYY-MM-DD variant would parse week-year and day-of-year instead of year and day-of-month
    new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").parse("3333-05-25 12:34:56").getTime

  private val primitiveTypes = Seq(
    DataTypes.BooleanType,
    DataTypes.ByteType,
    DataTypes.ShortType,
    DataTypes.IntegerType,
    DataTypes.LongType,
    DataTypes.FloatType,
    DataTypes.DoubleType,
    DataTypes.createDecimalType(10, 2),
    DataTypes.createDecimalType(36, 18),
    DataTypes.DateType,
    DataTypes.TimestampType,
    DataTypes.TimestampNTZType,
    DataTypes.StringType,
    DataTypes.BinaryType)

  private def filteredPrimitives(excludeTypes: Seq[DataType]) = {
    primitiveTypes.filterNot { dataType =>
      excludeTypes.exists {
        case _: DecimalType =>
          // for DecimalType, match if the type is also a DecimalType (ignore precision/scale)
          dataType.isInstanceOf[DecimalType]
        case excludeType =>
          dataType == excludeType
      }
    }
  }

  def generateDataFrame(
      r: Random,
      spark: SparkSession,
      numRows: Int,
      options: DataGenOptions): DataFrame = {

    val filteredPrimitiveTypes = filteredPrimitives(options.excludeTypes)
    val dataTypes = ListBuffer[DataType]()
    dataTypes.appendAll(filteredPrimitiveTypes)

    val arraysOfPrimitives = filteredPrimitiveTypes.map(DataTypes.createArrayType)

    if (options.generateStruct) {
      dataTypes += StructType(filteredPrimitiveTypes.zipWithIndex.map(x =>
        StructField(s"c${x._2}", x._1, nullable = true)))

      if (options.generateArray) {
        dataTypes += StructType(arraysOfPrimitives.zipWithIndex.map(x =>
          StructField(s"c${x._2}", x._1, nullable = true)))
      }
    }

    if (options.generateMap) {
      dataTypes += MapType(DataTypes.IntegerType, DataTypes.StringType)
    }

    if (options.generateArray) {
      dataTypes.appendAll(arraysOfPrimitives)

      if (options.generateStruct) {
        dataTypes += DataTypes.createArrayType(
          StructType(filteredPrimitiveTypes.zipWithIndex.map(x =>
            StructField(s"c${x._2}", x._1, nullable = true))))
      }

      if (options.generateMap) {
        dataTypes += DataTypes.createArrayType(
          MapType(DataTypes.IntegerType, DataTypes.StringType))
      }
    }

    // generate schema using random data types
    val fields = dataTypes.zipWithIndex
      .map(i => StructField(s"c${i._2}", i._1, nullable = true))
    val schema = StructType(fields.toSeq)

    // generate columnar data
    val cols: Seq[Seq[Any]] =
      schema.fields.map(f => generateColumn(r, f.dataType, numRows, options)).toSeq

    // convert to rows
    val rows = Range(0, numRows).map(rowIndex => {
      Row.fromSeq(cols.map(_(rowIndex)))
    })

    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
  }

  private def generateColumn(
      r: Random,
      dataType: DataType,
      numRows: Int,
      options: DataGenOptions): Seq[Any] = {
    dataType match {
      case ArrayType(elementType, _) =>
        val values = generateColumn(r, elementType, numRows, options)
        val list = ListBuffer[Any]()
        for (i <- 0 until numRows) {
          // roughly one in ten array entries is null when nulls are allowed
          if (i % 10 == 0 && options.allowNull) {
            list += null
          } else {
            list += Range(0, r.nextInt(5)).map(j => values((i + j) % values.length)).toArray
          }
        }
        list.toSeq
      case StructType(fields) =>
        val values = fields.map(f => generateColumn(r, f.dataType, numRows, options))
        Range(0, numRows).map(i => Row(values.indices.map(j => values(j)(i)): _*))
      case MapType(keyType, valueType, _) =>
        // map keys must not be null, so disable null generation for both columns
        val mapOptions = options.copy(allowNull = false)
        val k = generateColumn(r, keyType, numRows, mapOptions)
        val v = generateColumn(r, valueType, numRows, mapOptions)
        k.zip(v).map(x => Map(x._1 -> x._2))
      // the narrower integral types and booleans are derived from the Long generator
      // (null entries unbox to 0 under asInstanceOf[Long])
      case DataTypes.BooleanType =>
        generateColumn(r, DataTypes.LongType, numRows, options)
          .map(_.asInstanceOf[Long].toShort)
          .map(s => s % 2 == 0)
      case DataTypes.ByteType =>
        generateColumn(r, DataTypes.LongType, numRows, options)
          .map(_.asInstanceOf[Long].toByte)
      case DataTypes.ShortType =>
        generateColumn(r, DataTypes.LongType, numRows, options)
          .map(_.asInstanceOf[Long].toShort)
      case DataTypes.IntegerType =>
        generateColumn(r, DataTypes.LongType, numRows, options)
          .map(_.asInstanceOf[Long].toInt)
      case DataTypes.LongType =>
        // bias generation toward boundary values that are likely to expose edge cases
        Range(0, numRows).map(_ => {
          r.nextInt(50) match {
            case 0 if options.allowNull => null
            case 1 => 0L
            case 2 => Byte.MinValue.toLong
            case 3 => Byte.MaxValue.toLong
            case 4 => Short.MinValue.toLong
            case 5 => Short.MaxValue.toLong
            case 6 => Int.MinValue.toLong
            case 7 => Int.MaxValue.toLong
            case 8 => Long.MinValue
            case 9 => Long.MaxValue
            case _ => r.nextLong()
          }
        })
      case DataTypes.FloatType =>
        Range(0, numRows).map(_ => {
          r.nextInt(20) match {
            case 0 if options.allowNull => null
            case 1 => Float.NegativeInfinity
            case 2 => Float.PositiveInfinity
            case 3 => Float.MinValue
            case 4 => Float.MaxValue
            case 5 => 0.0f
            case 6 if options.generateNegativeZero => -0.0f
            case _ => r.nextFloat()
          }
        })
      case DataTypes.DoubleType =>
        Range(0, numRows).map(_ => {
          r.nextInt(20) match {
            case 0 if options.allowNull => null
            case 1 => Double.NegativeInfinity
            case 2 => Double.PositiveInfinity
            case 3 => Double.MinValue
            case 4 => Double.MaxValue
            case 5 => 0.0
            case 6 if options.generateNegativeZero => -0.0
            case _ => r.nextDouble()
          }
        })
      case dt: DecimalType =>
        Range(0, numRows).map(_ =>
          new BigDecimal(r.nextDouble()).setScale(dt.scale, RoundingMode.HALF_UP))
      case DataTypes.StringType =>
        Range(0, numRows).map(_ => {
          r.nextInt(10) match {
            case 0 if options.allowNull => null
            case 1 => r.nextInt().toByte.toString
            case 2 => r.nextLong().toString
            case 3 => r.nextDouble().toString
            case 4 => RandomStringUtils.randomAlphabetic(8)
            case _ => r.nextString(8)
          }
        })
      case DataTypes.BinaryType =>
        generateColumn(r, DataTypes.StringType, numRows, options)
          .map {
            case x: String =>
              x.getBytes(Charset.defaultCharset())
            case _ =>
              null
          }
      case DataTypes.DateType =>
        Range(0, numRows).map(_ => new java.sql.Date(options.baseDate + r.nextInt()))
      case DataTypes.TimestampType =>
        Range(0, numRows).map(_ => new Timestamp(options.baseDate + r.nextInt()))
      case DataTypes.TimestampNTZType =>
        Range(0, numRows).map(_ =>
          LocalDateTime.ofInstant(
            Instant.ofEpochMilli(options.baseDate + r.nextInt()),
            ZoneId.systemDefault()))
      case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
    }
  }
}

case class DataGenOptions(
    allowNull: Boolean = true,
    generateNegativeZero: Boolean = true,
    baseDate: Long = FuzzDataGenerator.defaultBaseDate,
    generateArray: Boolean = false,
    generateStruct: Boolean = false,
    generateMap: Boolean = false,
    excludeTypes: Seq[DataType] = Seq.empty)
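For reference, a minimal sketch of how a test suite might drive this generator. This is hypothetical caller code, not part of the commit; the local SparkSession setup, the seed of 42, and the /tmp output path are illustrative assumptions, chosen so that a failing run can be reproduced from the same seed.

// hypothetical test harness code, not part of this commit
import scala.util.Random

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DataTypes

import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// enable nested types; note that excluding one DecimalType would exclude
// all decimal columns, regardless of precision and scale
val options = DataGenOptions(
  generateArray = true,
  generateStruct = true,
  generateMap = true,
  excludeTypes = Seq(DataTypes.BinaryType))

// fixed seed makes the generated data deterministic and failures reproducible
val df = FuzzDataGenerator.generateDataFrame(new Random(42), spark, 1000, options)
df.write.mode("overwrite").parquet("/tmp/fuzz-test-data")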
