|
1 | 1 | from datetime import timedelta, datetime |
2 | 2 |
|
3 | 3 | import pytest |
4 | | -from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType |
| 4 | +from pyspark.sql.types import ( |
| 5 | + StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType, |
| 6 | + ShortType, LongType |
| 7 | +) |
| 8 | + |
5 | 9 |
|
6 | 10 | import dbldatagen as dg |
7 | 11 | from dbldatagen import DataGenerator |
@@ -403,6 +407,28 @@ def test_basic_prefix(self): |
403 | 407 | rowCount = formattedDF.count() |
404 | 408 | assert rowCount == 1000 |
405 | 409 |
|
    def test_missing_range_values(self):
        """Verify adjustForColumnDatatype fills in absent range bounds.

        For each numeric Spark SQL column type, an NRange constructed with
        only one bound should have the missing bound defaulted to that
        type's natural numeric limit (as reported by
        NRange._getNumericDataTypeRange), and the step defaulted to 1.
        """
        column_types = [FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()]
        for column_type in column_types:
            # One range missing its lower bound, one missing its upper bound.
            range_no_min = NRange(maxValue=1.0)
            range_no_max = NRange(minValue=0.0)
            range_no_min.adjustForColumnDatatype(column_type)
            # NOTE(review): these asserts read `.min` / `.max`, while
            # test_range_with_until reads `.minValue` / `.maxValue` on the same
            # class — confirm which attribute names NRange actually exposes.
            assert range_no_min.min == NRange._getNumericDataTypeRange(column_type)[0]
            assert range_no_min.step == 1
            range_no_max.adjustForColumnDatatype(column_type)
            assert range_no_max.max == NRange._getNumericDataTypeRange(column_type)[1]
            assert range_no_max.step == 1
| 421 | + |
| 422 | + def test_range_with_until(self): |
| 423 | + range_until = NRange(step=2, until=100) |
| 424 | + range_until.adjustForColumnDatatype(IntegerType()) |
| 425 | + assert range_until.minValue == 0 |
| 426 | + assert range_until.maxValue == 101 |
| 427 | + |
| 428 | + def test_empty_range(self): |
| 429 | + empty_range = NRange() |
| 430 | + assert empty_range.isEmpty() |
| 431 | + |
406 | 432 | def test_reversed_ranges(self): |
407 | 433 | testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, |
408 | 434 | partitions=4) |
|
0 commit comments