Skip to content

Commit 99ba358

Browse files
committed
Add tests for better coverage of NRange
1 parent 5c62dfd commit 99ba358

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

tests/test_options.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,25 @@ def test_random2(self):
218218
colSpec3 = ds.getColumnSpec("code3")
219219
assert colSpec3.random is True
220220

221+
def test_random3(self):
    """Verify that ``random=True`` yields non-ordinal values for every numeric column type.

    Builds a 500-row dataset with one random column per numeric type
    (decimal, float, double, byte, short, integer, long) and checks that
    ordering by each column changes the row order — if the generated values
    were already in row order, random generation did not take effect.
    """
    # will have implied column `id` for ordinal of row
    ds = (
        dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=500, partitions=1, random=True)
        .withIdOutput()
        .withColumn("val1", "decimal(5,2)", maxValue=20.0, step=0.01, random=True)
        .withColumn("val2", "float", maxValue=20.0, random=True)
        .withColumn("val3", "double", maxValue=20.0, random=True)
        .withColumn("val4", "byte", maxValue=15, random=True)
        .withColumn("val5", "short", maxValue=31, random=True)
        .withColumn("val6", "integer", maxValue=63, random=True)
        .withColumn("val7", "long", maxValue=127, random=True)
    )

    df = ds.build()
    # Hoist the unordered collect out of the loop: it is loop-invariant and
    # each collect() triggers a full Spark job, so the original re-ran it
    # once per column for no benefit.
    baseline_rows = df.collect()
    cols = ["val1", "val2", "val3", "val4", "val5", "val6", "val7"]
    for col in cols:
        assert baseline_rows != df.orderBy(col).collect(), f"Random values were not generated for {col}"
239+
221240
def test_random_multiple_columns(self):
222241
# will have implied column `id` for ordinal of row
223242
ds = (

tests/test_quick_tests.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from datetime import timedelta, datetime
22

33
import pytest
4-
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType
4+
from pyspark.sql.types import (
5+
StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType,
6+
ShortType, LongType
7+
)
8+
59

610
import dbldatagen as dg
711
from dbldatagen import DataGenerator
@@ -403,6 +407,28 @@ def test_basic_prefix(self):
403407
rowCount = formattedDF.count()
404408
assert rowCount == 1000
405409

410+
def test_missing_range_values(self):
    """Check that ``adjustForColumnDatatype`` fills in a missing bound.

    For each numeric Spark type, an NRange constructed without a min (or max)
    should pick up the datatype's limit for that bound after adjustment, and
    its step should default to 1.
    """
    for dtype in (FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()):
        # Natural (min, max) limits for this datatype, looked up once per type.
        type_min, type_max = NRange._getNumericDataTypeRange(dtype)

        missing_min = NRange(maxValue=1.0)
        missing_min.adjustForColumnDatatype(dtype)
        assert missing_min.min == type_min
        assert missing_min.step == 1

        missing_max = NRange(minValue=0.0)
        missing_max.adjustForColumnDatatype(dtype)
        assert missing_max.max == type_max
        assert missing_max.step == 1
421+
422+
def test_range_with_until(self):
    """An NRange built with ``until`` should resolve to minValue 0 and an
    adjusted maxValue (here 101 for until=100) after datatype adjustment."""
    bounded = NRange(step=2, until=100)
    bounded.adjustForColumnDatatype(IntegerType())
    assert bounded.minValue == 0
    assert bounded.maxValue == 101
427+
428+
def test_empty_range(self):
    """An NRange constructed with no arguments must report itself as empty."""
    assert NRange().isEmpty()
431+
406432
def test_reversed_ranges(self):
407433
testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000,
408434
partitions=4)

0 commit comments

Comments
 (0)