From 475c1e89c03eae5178180e6ca36f8bad1f67710b Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 14:34:03 -0500 Subject: [PATCH 1/8] Format modules --- dbldatagen/__init__.py | 81 +- dbldatagen/_version.py | 2 +- dbldatagen/column_generation_spec.py | 416 +++++---- dbldatagen/column_spec_options.py | 118 ++- dbldatagen/config.py | 1 + dbldatagen/constraints/__init__.py | 19 +- dbldatagen/constraints/chained_relation.py | 48 +- dbldatagen/constraints/constraint.py | 176 ++-- .../constraints/literal_range_constraint.py | 37 +- .../literal_relation_constraint.py | 38 +- dbldatagen/constraints/negative_values.py | 44 +- dbldatagen/constraints/positive_values.py | 45 +- .../constraints/ranged_values_constraint.py | 43 +- dbldatagen/constraints/sql_expr.py | 38 +- dbldatagen/constraints/unique_combinations.py | 91 +- dbldatagen/data_analyzer.py | 51 +- dbldatagen/data_generator.py | 256 ++--- dbldatagen/datagen_types.py | 13 + dbldatagen/datarange.py | 61 +- dbldatagen/datasets/__init__.py | 40 +- dbldatagen/datasets/basic_geometries.py | 110 ++- .../datasets/basic_process_historian.py | 69 +- dbldatagen/datasets/basic_stock_ticker.py | 114 ++- dbldatagen/datasets/basic_telematics.py | 66 +- dbldatagen/datasets/basic_user.py | 37 +- dbldatagen/datasets/benchmark_groupby.py | 69 +- dbldatagen/datasets/dataset_provider.py | 167 ++-- .../multi_table_sales_order_provider.py | 411 +++++--- .../multi_table_telephony_provider.py | 424 ++++++--- dbldatagen/datasets_object.py | 76 +- dbldatagen/daterange.py | 243 +++-- dbldatagen/distributions/__init__.py | 6 +- dbldatagen/distributions/beta.py | 80 +- dbldatagen/distributions/data_distribution.py | 71 +- .../distributions/exponential_distribution.py | 83 +- dbldatagen/distributions/gamma.py | 84 +- .../distributions/normal_distribution.py | 70 +- dbldatagen/multi_table_builder.py | 276 ++++++ dbldatagen/nrange.py | 287 ++++-- dbldatagen/relation.py | 33 + dbldatagen/schema_parser.py | 88 +- dbldatagen/serialization.py | 25 +- dbldatagen/spark_singleton.py | 13 +- dbldatagen/text_generator_plugins.py | 15 +- dbldatagen/text_generators.py | 335 +++++-- dbldatagen/utils.py | 58 +- docs/source/conf.py | 11 +- docs/utils/mk_quick_index.py | 219 ++--- docs/utils/mk_requirements.py | 10 +- pyproject.toml | 41 +- tests/__init__.py | 34 +- tests/test_basic_test.py | 362 +++---- tests/test_build_planning.py | 550 +++++------ tests/test_columnGenerationSpec.py | 16 +- tests/test_complex_columns.py | 881 +++++++++++------- tests/test_constraints.py | 291 +++--- tests/test_data_generation_plugins.py | 71 +- tests/test_dependent_data.py | 188 ++-- tests/test_distributions.py | 245 +++-- tests/test_generation_from_data.py | 29 +- tests/test_html_utils.py | 25 +- tests/test_iltext_generation.py | 212 +++-- tests/test_large_schema.py | 317 +++---- tests/test_logging.py | 23 +- tests/test_multi_table.py | 146 +++ tests/test_options.py | 73 +- tests/test_output.py | 12 +- tests/test_pandas_integration.py | 50 +- tests/test_quick_tests.py | 800 ++++++++-------- tests/test_ranged_values_and_dates.py | 689 +++++++------- tests/test_repeatable_data.py | 233 +++-- tests/test_schema_parser.py | 205 ++-- tests/test_scripting.py | 51 +- tests/test_serialization.py | 499 ++++++---- tests/test_serverless.py | 30 +- tests/test_shared_env.py | 1 + tests/test_standard_dataset_providers.py | 810 +++++++++++----- tests/test_standard_datasets.py | 139 +-- tests/test_streaming.py | 123 +-- tests/test_text_generation.py | 483 +++++----- tests/test_text_generator_basic.py | 
122 ++- tests/test_text_templates.py | 329 +++---- tests/test_topological_sort.py | 4 +- tests/test_types.py | 208 +++-- tests/test_utils.py | 213 +++-- tests/test_weights.py | 270 +++--- tutorial/1-Introduction.py | 99 +- tutorial/2-Basics.py | 414 ++++---- tutorial/3-ChangeDataCapture-example.py | 136 +-- tutorial/4-Generating-multi-table-data.py | 348 ++++--- 90 files changed, 9012 insertions(+), 5928 deletions(-) create mode 100644 dbldatagen/datagen_types.py create mode 100644 dbldatagen/multi_table_builder.py create mode 100644 dbldatagen/relation.py create mode 100644 tests/test_multi_table.py diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 3a00ce71..76aeb5c9 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -24,18 +24,44 @@ """ from .data_generator import DataGenerator -from .datagen_constants import DEFAULT_RANDOM_SEED, RANDOM_SEED_RANDOM, RANDOM_SEED_FIXED, \ - RANDOM_SEED_HASH_FIELD_NAME, MIN_PYTHON_VERSION, MIN_SPARK_VERSION, \ - INFER_DATATYPE, SPARK_DEFAULT_PARALLELISM -from .utils import ensure, topologicalSort, mkBoundsList, coalesce_values, \ - deprecated, parse_time_interval, DataGenError, split_list_matching_condition, strip_margins, \ - json_value_from_path, system_time_millis +from .datagen_constants import ( + DEFAULT_RANDOM_SEED, + RANDOM_SEED_RANDOM, + RANDOM_SEED_FIXED, + RANDOM_SEED_HASH_FIELD_NAME, + MIN_PYTHON_VERSION, + MIN_SPARK_VERSION, + INFER_DATATYPE, + SPARK_DEFAULT_PARALLELISM, +) +from .utils import ( + ensure, + topologicalSort, + mkBoundsList, + coalesce_values, + deprecated, + parse_time_interval, + DataGenError, + split_list_matching_condition, + strip_margins, + json_value_from_path, + system_time_millis, +) from ._version import __version__ from .column_generation_spec import ColumnGenerationSpec from .column_spec_options import ColumnSpecOptions -from .constraints import Constraint, ChainedRelation, LiteralRange, LiteralRelation, NegativeValues, PositiveValues, \ - RangedValues, SqlExpr, UniqueCombinations +from .constraints import ( + Constraint, + ChainedRelation, + LiteralRange, + LiteralRelation, + NegativeValues, + PositiveValues, + RangedValues, + SqlExpr, + UniqueCombinations, +) from .data_analyzer import DataAnalyzer from .schema_parser import SchemaParser from .daterange import DateRange @@ -48,24 +74,45 @@ from .html_utils import HtmlUtils from .datasets_object import Datasets from .config import OutputDataset +from .multi_table_builder import MultiTableBuilder +from .relation import ForeignKeyRelation +from .datagen_types import ColumnLike -__all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange", - "column_generation_spec", "utils", "function_builder", - "spark_singleton", "text_generators", "datarange", "datagen_constants", - "text_generator_plugins", "html_utils", "datasets_object", "constraints", "config" - ] +__all__ = [ + "data_generator", + "data_analyzer", + "schema_parser", + "daterange", + "nrange", + "column_generation_spec", + "utils", + "function_builder", + "spark_singleton", + "text_generators", + "datarange", + "datagen_constants", + "text_generator_plugins", + "html_utils", + "datasets_object", + "constraints", + "config", + "multi_table_builder", + "relation", + "datagen_types", +] def python_version_check(python_version_expected): """Check against Python version - Allows minimum version to be passed in to facilitate unit testing + Allows minimum version to be passed in to facilitate unit testing - :param python_version_expected: = minimum version of 
python to support as tuple e.g (3,6) - :return: True if passed + :param python_version_expected: = minimum version of python to support as tuple e.g (3,6) + :return: True if passed - """ + """ import sys + return sys.version_info >= python_version_expected diff --git a/dbldatagen/_version.py b/dbldatagen/_version.py index b4416ca0..420d289a 100644 --- a/dbldatagen/_version.py +++ b/dbldatagen/_version.py @@ -22,7 +22,7 @@ def get_version(version: str) -> VersionInfo: - """ Get version info object for library. + """Get version info object for library. :param version: version string to parse for version information diff --git a/dbldatagen/column_generation_spec.py b/dbldatagen/column_generation_spec.py index 62482b95..4c6ce9f4 100644 --- a/dbldatagen/column_generation_spec.py +++ b/dbldatagen/column_generation_spec.py @@ -10,17 +10,35 @@ import logging from pyspark.sql.functions import col, pandas_udf -from pyspark.sql.functions import lit, concat, rand, round as sql_round, array, expr, when, udf, \ - format_string +from pyspark.sql.functions import lit, concat, rand, round as sql_round, array, expr, when, udf, format_string import pyspark.sql.functions as F -from pyspark.sql.types import FloatType, IntegerType, StringType, DoubleType, BooleanType, \ - TimestampType, DataType, DateType, ArrayType, MapType, StructType +from pyspark.sql.types import ( + FloatType, + IntegerType, + StringType, + DoubleType, + BooleanType, + TimestampType, + DataType, + DateType, + ArrayType, + MapType, + StructType, +) from .column_spec_options import ColumnSpecOptions -from .datagen_constants import RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME, RANDOM_SEED_RANDOM, \ - DEFAULT_SEED_COLUMN, OPTION_RANDOM, OPTION_RANDOM_SEED, OPTION_RANDOM_SEED_METHOD, INFER_DATATYPE +from .datagen_constants import ( + RANDOM_SEED_FIXED, + RANDOM_SEED_HASH_FIELD_NAME, + RANDOM_SEED_RANDOM, + DEFAULT_SEED_COLUMN, + OPTION_RANDOM, + OPTION_RANDOM_SEED, + OPTION_RANDOM_SEED_METHOD, + INFER_DATATYPE, +) from .daterange import DateRange from .distributions import Normal, DataDistribution @@ -35,14 +53,16 @@ RAW_VALUES_COMPUTE_METHOD = "raw_values" AUTO_COMPUTE_METHOD = "auto" EXPR_OPTION = "expr" -COMPUTE_METHOD_VALID_VALUES = [HASH_COMPUTE_METHOD, - AUTO_COMPUTE_METHOD, - VALUES_COMPUTE_METHOD, - RAW_VALUES_COMPUTE_METHOD] +COMPUTE_METHOD_VALID_VALUES = [ + HASH_COMPUTE_METHOD, + AUTO_COMPUTE_METHOD, + VALUES_COMPUTE_METHOD, + RAW_VALUES_COMPUTE_METHOD, +] class ColumnGenerationSpec(SerializableToDict): - """ Column generation spec object - specifies how column is to be generated + """Column generation spec object - specifies how column is to be generated Each column to be output will have a corresponding ColumnGenerationSpec object. 
This is added explicitly using the DataGenerators `withColumnSpec` or `withColumn` methods @@ -85,10 +105,7 @@ class ColumnGenerationSpec(SerializableToDict): datatype: DataType #: maxValue values for each column type, only if where value is intentionally restricted - _max_type_range = { - 'byte': 256, - 'short': 65536 - } + _max_type_range = {'byte': 256, 'short': 65536} _ARRAY_STRUCT_TYPE = "array" @@ -97,11 +114,28 @@ class ColumnGenerationSpec(SerializableToDict): # restrict spurious messages from java gateway logging.getLogger("py4j").setLevel(logging.WARNING) - def __init__(self, name, colType=None, *, minValue=0, maxValue=None, step=1, prefix='', random=False, - distribution=None, baseColumn=None, randomSeed=None, randomSeedMethod=None, - implicit=False, omit=False, nullable=True, debug=False, verbose=False, - seedColumnName=DEFAULT_SEED_COLUMN, - **kwargs): + def __init__( + self, + name, + colType=None, + *, + minValue=0, + maxValue=None, + step=1, + prefix='', + random=False, + distribution=None, + baseColumn=None, + randomSeed=None, + randomSeedMethod=None, + implicit=False, + omit=False, + nullable=True, + debug=False, + verbose=False, + seedColumnName=DEFAULT_SEED_COLUMN, + **kwargs, + ): # set up logging self.verbose = verbose @@ -138,13 +172,22 @@ def __init__(self, name, colType=None, *, minValue=0, maxValue=None, step=1, pre # to allow for open ended extension of many column attributes, we use a few specific # parameters and pass the rest as keyword arguments - supplied_options = {'name': name, 'minValue': minValue, 'type': colType, - 'maxValue': maxValue, 'step': step, - 'prefix': prefix, 'baseColumn': baseColumn, - OPTION_RANDOM: random, 'distribution': distribution, - OPTION_RANDOM_SEED_METHOD: randomSeedMethod, OPTION_RANDOM_SEED: randomSeed, - 'omit': omit, 'nullable': nullable, 'implicit': implicit - } + supplied_options = { + 'name': name, + 'minValue': minValue, + 'type': colType, + 'maxValue': maxValue, + 'step': step, + 'prefix': prefix, + 'baseColumn': baseColumn, + OPTION_RANDOM: random, + 'distribution': distribution, + OPTION_RANDOM_SEED_METHOD: randomSeedMethod, + OPTION_RANDOM_SEED: randomSeed, + 'omit': omit, + 'nullable': nullable, + 'implicit': implicit, + } supplied_options.update(kwargs) @@ -192,8 +235,10 @@ def __init__(self, name, colType=None, *, minValue=0, maxValue=None, step=1, pre # use of a random seed method will ensure that we have repeatability of data generation assert randomSeed is None or type(randomSeed) in [int, float], "seed should be None or numeric" - assert randomSeedMethod is None or randomSeedMethod in [RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME], \ - f"`randomSeedMethod` should be none or `{RANDOM_SEED_FIXED}` or `{RANDOM_SEED_HASH_FIELD_NAME}`" + assert randomSeedMethod is None or randomSeedMethod in [ + RANDOM_SEED_FIXED, + RANDOM_SEED_HASH_FIELD_NAME, + ], f"`randomSeedMethod` should be none or `{RANDOM_SEED_FIXED}` or `{RANDOM_SEED_HASH_FIELD_NAME}`" self._randomSeedMethod = self[OPTION_RANDOM_SEED_METHOD] self.random = self[OPTION_RANDOM] @@ -213,8 +258,13 @@ def __init__(self, name, colType=None, *, minValue=0, maxValue=None, step=1, pre # value of `base_column_type` must be `None`,"values", "raw_values", "auto", or "hash" # this is the method of computing current column value from base column, not the data type of the base column - allowed_compute_methods = [AUTO_COMPUTE_METHOD, VALUES_COMPUTE_METHOD, HASH_COMPUTE_METHOD, - RAW_VALUES_COMPUTE_METHOD, None] + allowed_compute_methods = [ + AUTO_COMPUTE_METHOD, + 
VALUES_COMPUTE_METHOD, + HASH_COMPUTE_METHOD, + RAW_VALUES_COMPUTE_METHOD, + None, + ] self._csOptions.checkOptionValues("baseColumnType", allowed_compute_methods) self._baseColumnComputeMethod = self['baseColumnType'] @@ -267,58 +317,73 @@ def __init__(self, name, colType=None, *, minValue=0, maxValue=None, step=1, pre # handle default method of computing the base column value # if we have text manipulation, use 'values' as default for format but 'hash' as default if # its a column with multiple values - if self._baseColumnComputeMethod in [None, AUTO_COMPUTE_METHOD] \ - and (self.textGenerator is not None or self['format'] is not None - or self['prefix'] is not None or self['suffix'] is not None): + if self._baseColumnComputeMethod in [None, AUTO_COMPUTE_METHOD] and ( + self.textGenerator is not None + or self['format'] is not None + or self['prefix'] is not None + or self['suffix'] is not None + ): if self.values is not None: - self.logger.info("""Column [%s] has no `base_column_type` attribute and uses discrete values + self.logger.info( + """Column [%s] has no `base_column_type` attribute and uses discrete values => Assuming `hash` for attribute `base_column_type`. => Use explicit value for `base_column_type` if alternate interpretation needed - """, self.name) + """, + self.name, + ) self._baseColumnComputeMethod = HASH_COMPUTE_METHOD else: - self.logger.info("""Column [%s] has no `base_column_type` attribute specified for formatted text + self.logger.info( + """Column [%s] has no `base_column_type` attribute specified for formatted text => Assuming `values` for attribute `base_column_type`. => Use explicit value for `base_column_type` if alternate interpretation needed - """, self.name) + """, + self.name, + ) self._baseColumnComputeMethod = VALUES_COMPUTE_METHOD # adjust the range by merging type and range information - self._dataRange = self._computeAdjustedRangeForColumn(colType=colType, - c_min=c_min, c_max=c_max, c_step=c_step, - c_begin=c_begin, c_end=c_end, - c_interval=c_interval, - c_unique=unique_values, c_range=data_range) + self._dataRange = self._computeAdjustedRangeForColumn( + colType=colType, + c_min=c_min, + c_max=c_max, + c_step=c_step, + c_begin=c_begin, + c_end=c_end, + c_interval=c_interval, + c_unique=unique_values, + c_range=data_range, + ) if self.distribution is not None: - ensure((self._dataRange is not None and self._dataRange.isFullyPopulated()) - or - self.values is not None, - """When using an explicit distribution, provide a fully populated range or a set of values""") + ensure( + (self._dataRange is not None and self._dataRange.isFullyPopulated()) or self.values is not None, + """When using an explicit distribution, provide a fully populated range or a set of values""", + ) # set up the temporary columns needed for data generation self._setupTemporaryColumns() def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. 
+ :return: Python dictionary representation of the object """ _options = self._csOptions.options.copy() _options["colName"] = _options.pop("name", self.name) _options["colType"] = _options.pop("type", self.datatype).simpleString() _options["kind"] = self.__class__.__name__ return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } def _temporaryRename(self, tmpName): - """ Create enter / exit object to support temporary renaming of column spec + """Create enter / exit object to support temporary renaming of column spec This is to support the functionality: @@ -343,14 +408,14 @@ def _temporaryRename(self, tmpName): class RenameEnterExit: def __init__(self, columnSpec, newName): - """ Save column spec and old name to support enter / exit semantics """ + """Save column spec and old name to support enter / exit semantics""" self._cs = columnSpec self._oldName = columnSpec.name self._newName = newName self._randomSeed = columnSpec._randomSeed def __enter__(self): - """ Return the inner column spec object """ + """Return the inner column spec object""" self._cs.name = self._newName if self._cs._randomSeedMethod == RANDOM_SEED_HASH_FIELD_NAME: @@ -384,7 +449,7 @@ def __exit__(self, exc_type, exc_value, tb): @property def specOptions(self): - """ get column spec options for spec + """get column spec options for spec .. note:: This is intended for testing use only. @@ -416,28 +481,27 @@ def __deepcopy__(self, memo): @property def randomSeed(self): - """ get random seed for column spec""" + """get random seed for column spec""" return self._randomSeed @property def isRandom(self): - """ returns True if column will be randomly generated""" + """returns True if column will be randomly generated""" return self[OPTION_RANDOM] @property def textGenerator(self): - """ Get the text generator for the column spec""" + """Get the text generator for the column spec""" return self._textGenerator @property def inferDatatype(self): - """ If True indicates that datatype should be inferred to be result of computing SQL expression - """ + """If True indicates that datatype should be inferred to be result of computing SQL expression""" return self._inferDataType @property def baseColumns(self): - """ Return base columns as list of strings""" + """Return base columns as list of strings""" # if base column is string and contains multiple columns, split them # other build list of columns if needed @@ -449,7 +513,7 @@ def baseColumns(self): return [self.baseColumn] def _computeBasicDependencies(self): - """ get set of basic column dependencies. + """get set of basic column dependencies. 
These are used to compute the order of field evaluation @@ -461,18 +525,20 @@ def _computeBasicDependencies(self): return [self._seedColumnName] def setBaseColumnDatatypes(self, columnDatatypes): - """ Set the data types for the base columns + """Set the data types for the base columns :param column_datatypes: = list of data types for the base columns """ assert type(columnDatatypes) is list, " `column_datatypes` parameter must be list" - ensure(len(columnDatatypes) == len(self.baseColumns), - "number of base column datatypes must match number of base columns") + ensure( + len(columnDatatypes) == len(self.baseColumns), + "number of base column datatypes must match number of base columns", + ) self._baseColumnDatatypes = columnDatatypes.copy() def _setupTemporaryColumns(self): - """ Set up any temporary columns needed for test data generation. + """Set up any temporary columns needed for test data generation. For some types of test data, intermediate columns are used in the data generation process but dropped from the final output @@ -481,16 +547,19 @@ def _setupTemporaryColumns(self): # if its a weighted values column, then create temporary for it # not supported for feature / array columns for now min_num_columns, max_num_columns, struct_type = self._getMultiColumnDetails(validate=False) - ensure(max_num_columns is None or max_num_columns <= 1, - "weighted columns not supported for multi-column or multi-feature values") + ensure( + max_num_columns is None or max_num_columns <= 1, + "weighted columns not supported for multi-column or multi-feature values", + ) if self.random: temp_name = f"_rnd_{self.name}" self.dependencies.append(temp_name) desc = f"adding temporary column {temp_name} required by {self.name}" self._initialBuildPlan.append(desc) sql_random_generator = self._getUniformRandomSQLExpression(self.name) - self.temporaryColumns.append((temp_name, DoubleType(), {'expr': sql_random_generator, 'omit': True, - 'description': desc})) + self.temporaryColumns.append( + (temp_name, DoubleType(), {'expr': sql_random_generator, 'omit': True, 'description': desc}) + ) self._weightedBaseColumn = temp_name else: # create temporary expression mapping values to range of weights @@ -500,22 +569,35 @@ def _setupTemporaryColumns(self): self._initialBuildPlan.append(desc) # use a base expression based on mapping base column to size of data - sql_scaled_generator = self._getScaledIntSQLExpression(self.name, - scale=sum(self.weights), - base_columns=self.baseColumns, - base_datatypes=self._baseColumnDatatypes, - compute_method=self._baseColumnComputeMethod, - normalize=True) - - self.logger.debug("""building scaled sql expression : '%s' + sql_scaled_generator = self._getScaledIntSQLExpression( + self.name, + scale=sum(self.weights), + base_columns=self.baseColumns, + base_datatypes=self._baseColumnDatatypes, + compute_method=self._baseColumnComputeMethod, + normalize=True, + ) + + self.logger.debug( + """building scaled sql expression : '%s' with base column: %s, dependencies: %s""", - sql_scaled_generator, - self.baseColumn, - self.dependencies) - - self.temporaryColumns.append((temp_name, DoubleType(), {'expr': sql_scaled_generator, 'omit': True, - 'baseColumn': self.baseColumn, - 'description': desc})) + sql_scaled_generator, + self.baseColumn, + self.dependencies, + ) + + self.temporaryColumns.append( + ( + temp_name, + DoubleType(), + { + 'expr': sql_scaled_generator, + 'omit': True, + 'baseColumn': self.baseColumn, + 'description': desc, + }, + ) + ) self._weightedBaseColumn = temp_name def 
_setup_logger(self): @@ -531,20 +613,20 @@ def _setup_logger(self): else: self.logger.setLevel(logging.WARNING) - def _computeAdjustedRangeForColumn(self, colType, c_min, c_max, c_step, *, c_begin, c_end, c_interval, c_range, - c_unique): - """Determine adjusted range for data column - """ + def _computeAdjustedRangeForColumn( + self, colType, c_min, c_max, c_step, *, c_begin, c_end, c_interval, c_range, c_unique + ): + """Determine adjusted range for data column""" assert colType is not None, "`colType` must be non-None instance" if type(colType) is DateType or type(colType) is TimestampType: - return self._computeAdjustedDateTimeRangeForColumn(colType, c_begin, c_end, c_interval, - c_range=c_range, - c_unique=c_unique) + return self._computeAdjustedDateTimeRangeForColumn( + colType, c_begin, c_end, c_interval, c_range=c_range, c_unique=c_unique + ) else: - return self._computeAdjustedNumericRangeForColumn(colType, c_min, c_max, c_step, - c_range=c_range, - c_unique=c_unique) + return self._computeAdjustedNumericRangeForColumn( + colType, c_min, c_max, c_step, c_range=c_range, c_unique=c_unique + ) def _computeAdjustedNumericRangeForColumn(self, colType, c_min, c_max, c_step, *, c_range, c_unique): """Determine adjusted range for data column @@ -577,10 +659,12 @@ def _computeAdjustedNumericRangeForColumn(self, colType, c_min, c_max, c_step, * result = NRange(effective_min, unique_max, effective_step) if result.maxValue is not None and effective_max is not None and result.maxValue > effective_max: - self.logger.warning("Computed maxValue for column [%s] of %s is greater than specified maxValue %s", - self.name, - result.maxValue, - effective_max) + self.logger.warning( + "Computed maxValue for column [%s] of %s is greater than specified maxValue %s", + self.name, + result.maxValue, + effective_max, + ) elif c_range is not None: result = c_range elif c_range is None: @@ -596,8 +680,7 @@ def _computeAdjustedNumericRangeForColumn(self, colType, c_min, c_max, c_step, * return result def _computeAdjustedDateTimeRangeForColumn(self, colType, c_begin, c_end, c_interval, *, c_range, c_unique): - """Determine adjusted range for Date or Timestamp data column - """ + """Determine adjusted range for Date or Timestamp data column""" effective_begin, effective_end, effective_interval = None, None, None if c_range is not None and type(c_range) is DateRange: effective_begin = c_range.begin @@ -616,7 +699,7 @@ def _computeAdjustedDateTimeRangeForColumn(self, colType, c_begin, c_end, c_inte return result def _getUniformRandomExpression(self, col_name): - """ Get random expression accounting for seed method + """Get random expression accounting for seed method :returns: expression of ColDef form - i.e `lit`, `expr` etc @@ -632,7 +715,7 @@ def _getUniformRandomExpression(self, col_name): return rand() def _getRandomExpressionForDistribution(self, col_name, col_distribution): - """ Get random expression accounting for seed method + """Get random expression accounting for seed method :returns: expression of ColDef form - i.e `lit`, `expr` etc @@ -640,15 +723,16 @@ def _getRandomExpressionForDistribution(self, col_name, col_distribution): """ assert col_name is not None and len(col_name) > 0, "`col_name` must not be None and non empty" assert col_distribution is not None, "`col_distribution` must not be None" - assert isinstance(col_distribution, DataDistribution), \ - "`distribution` object must be an instance of data distribution" + assert isinstance( + col_distribution, DataDistribution + ), "`distribution` 
object must be an instance of data distribution" self.executionHistory.append(f".. random number generation via distribution `{col_distribution}`") return col_distribution.generateNormalizedDistributionSample() def _getUniformRandomSQLExpression(self, col_name): - """ Get random SQL expression accounting for seed method + """Get random SQL expression accounting for seed method :returns: expression as a SQL string """ @@ -662,9 +746,10 @@ def _getUniformRandomSQLExpression(self, col_name): else: return "rand()" - def _getScaledIntSQLExpression(self, col_name, scale, base_columns, *, base_datatypes=None, compute_method=None, - normalize=False): - """ Get scaled numeric expression + def _getScaledIntSQLExpression( + self, col_name, scale, base_columns, *, base_datatypes=None, compute_method=None, normalize=False + ): + """Get scaled numeric expression This will produce a scaled SQL expression from the base columns @@ -683,11 +768,12 @@ def _getScaledIntSQLExpression(self, col_name, scale, base_columns, *, base_data assert col_name is not None, "`col_name` must not be None" assert self.name is not None, "`self.name` must not be None" assert scale is not None, "`scale` must not be None" - assert (compute_method is None or - compute_method in COMPUTE_METHOD_VALID_VALUES), "`compute_method` must be valid value " - assert (base_columns is not None and - type(base_columns) is list - and len(base_columns) > 0), "Base columns must be a non-empty list" + assert ( + compute_method is None or compute_method in COMPUTE_METHOD_VALID_VALUES + ), "`compute_method` must be valid value " + assert ( + base_columns is not None and type(base_columns) is list and len(base_columns) > 0 + ), "Base columns must be a non-empty list" effective_compute_method = compute_method @@ -695,7 +781,8 @@ def _getScaledIntSQLExpression(self, col_name, scale, base_columns, *, base_data if len(base_columns) > 1: if compute_method == VALUES_COMPUTE_METHOD: self.logger.warning( - "For column generation with values and multiple base columns, data will be computed with `hash`") + "For column generation with values and multiple base columns, data will be computed with `hash`" + ) effective_compute_method = HASH_COMPUTE_METHOD if effective_compute_method is None or effective_compute_method is AUTO_COMPUTE_METHOD: @@ -716,11 +803,11 @@ def _getScaledIntSQLExpression(self, col_name, scale, base_columns, *, base_data @property def isWeightedValuesColumn(self): - """ check if column is a weighed values column """ + """check if column is a weighed values column""" return self['weights'] is not None and self.values is not None def getNames(self): - """ get column names as list of strings""" + """get column names as list of strings""" min_num_columns, max_num_columns, struct_type = self._getMultiColumnDetails(validate=False) if max_num_columns > 1 and struct_type is None: @@ -729,7 +816,7 @@ def getNames(self): return [self.name] def getNamesAndTypes(self): - """ get column names as list of tuples `(name, datatype)`""" + """get column names as list of tuples `(name, datatype)`""" min_num_columns, max_num_columns, struct_type = self._getMultiColumnDetails(validate=False) if max_num_columns > 1 and struct_type is None: @@ -738,18 +825,18 @@ def getNamesAndTypes(self): return [(self.name, self.datatype)] def keys(self): - """ Get the keys as list of strings """ + """Get the keys as list of strings""" assert self._csOptions is not None, "self._csOptions should be non-empty" return self._csOptions.keys() def __getitem__(self, key): - """ implement 
the built in dereference by key behavior """ + """implement the built in dereference by key behavior""" assert key is not None, "key should be non-empty" return self._csOptions.getOrElse(key, None) @property def isFieldOmitted(self): - """ check if this field should be omitted from the output + """check if this field should be omitted from the output If the field is omitted from the output, the field is available for use in expressions etc. but dropped from the final set of fields @@ -799,8 +886,7 @@ def step(self): @property def exprs(self): - """get the column generation `exprs` attribute used to generate values for this column. - """ + """get the column generation `exprs` attribute used to generate values for this column.""" return self['exprs'] @property @@ -867,7 +953,7 @@ def structType(self): return self['structType'] def getOrElse(self, key, default=None): - """ Get value for option key if it exists or else return default + """Get value for option key if it exists or else return default :param key: key name for option :param default: default value if option was not provided @@ -877,7 +963,7 @@ def getOrElse(self, key, default=None): return self._csOptions.getOrElse(key, default) def getPlanEntry(self): - """ Get execution plan entry for object + """Get execution plan entry for object :returns: String representation of plan entry """ @@ -893,6 +979,7 @@ def _makeWeightedColumnValuesExpression(self, values, weights, seed_column_name) :returns: Spark SQL expr """ from .function_builder import ColumnGeneratorBuilder + assert values is not None, "`values` expression must be supplied as list of values" assert weights is not None, "`weights` expression must be list of weights" assert len(values) == len(weights), "`weights` and `values` lists must be of equal length" @@ -901,7 +988,7 @@ def _makeWeightedColumnValuesExpression(self, values, weights, seed_column_name) return expr(expr_str).astype(self.datatype) def _isRealValuedColumn(self): - """ determine if column is real valued + """determine if column is real valued :returns: Boolean - True if condition is true """ @@ -910,7 +997,7 @@ def _isRealValuedColumn(self): return col_type_name in ['double', 'float', 'decimal'] def _isDecimalColumn(self): - """ determine if column is decimal column + """determine if column is decimal column :returns: Boolean - True if condition is true """ @@ -919,7 +1006,7 @@ def _isDecimalColumn(self): return col_type_name == 'decimal' def _isContinuousValuedColumn(self): - """ determine if column generates continuous values + """determine if column generates continuous values :returns: Boolean - True if condition is true """ @@ -928,7 +1015,7 @@ def _isContinuousValuedColumn(self): return is_continuous def _getSeedExpression(self, base_column): - """ Get seed expression for column generation + """Get seed expression for column generation This is used to generate the base value for every column if using a single base column, then simply use that, otherwise use either @@ -959,7 +1046,7 @@ def _isStringField(self): return type(self.datatype) is StringType def _computeRangedColumn(self, datarange, base_column, is_random): - """ compute a ranged column + """compute a ranged column maxValue is maxValue actual value @@ -985,8 +1072,11 @@ def _computeRangedColumn(self, datarange, base_column, is_random): modulo_factor = lit(crange + 1) # following expression is needed as spark sql modulo of negative number is negative modulo_exp = ((self._getSeedExpression(base_column) % modulo_factor) + modulo_factor) % 
modulo_factor - baseval = (modulo_exp * lit(datarange.step)) if not is_random else ( - sql_round(random_generator * lit(crange)) * lit(datarange.step)) + baseval = ( + (modulo_exp * lit(datarange.step)) + if not is_random + else (sql_round(random_generator * lit(crange)) * lit(datarange.step)) + ) if self._baseColumnComputeMethod == VALUES_COMPUTE_METHOD: new_def = self._adjustForMinValue(baseval, datarange) @@ -1008,7 +1098,7 @@ def _computeRangedColumn(self, datarange, base_column, is_random): return new_def def _adjustForMinValue(self, baseval, datarange, force=False): - """ Adjust for minimum value of data range + """Adjust for minimum value of data range :param baseval: base expression :param datarange: data range to conform to :param force: always adjust (possibly for implicit cast reasons) @@ -1022,11 +1112,11 @@ def _adjustForMinValue(self, baseval, datarange, force=False): return new_def def _makeSingleGenerationExpression(self, index=None, use_pandas_optimizations=True): - """ generate column data for a single column value via Spark SQL expression + """generate column data for a single column value via Spark SQL expression - :param index: for multi column generation, specifies index of column being generated - :param use_pandas_optimizations: if True, uses Pandas vectorized optimizations. Defaults to `True` - :returns: spark sql `column` or expression that can be used to generate a column + :param index: for multi column generation, specifies index of column being generated + :param use_pandas_optimizations: if True, uses Pandas vectorized optimizations. Defaults to `True` + :returns: spark sql `column` or expression that can be used to generate a column """ self.logger.debug("building column : %s", self.name) @@ -1080,16 +1170,16 @@ def _makeSingleGenerationExpression(self, index=None, use_pandas_optimizations=T new_def = expr("NULL") elif self._dataRange is not None and self._dataRange.isFullyPopulated(): self.executionHistory.append(f".. computing ranged value: {self._dataRange}") - new_def = self._computeRangedColumn(base_column=self.baseColumn, datarange=self._dataRange, - is_random=col_is_rand) + new_def = self._computeRangedColumn( + base_column=self.baseColumn, datarange=self._dataRange, is_random=col_is_rand + ) elif type(self.datatype) is DateType: # TODO: fixup for date generation # record execution history self.executionHistory.append(".. using random date expression") sql_random_generator = self._getUniformRandomSQLExpression(self.name) - new_def = expr(f"date_sub(current_date, rounding({sql_random_generator}*1024))").astype( - self.datatype) + new_def = expr(f"date_sub(current_date, rounding({sql_random_generator}*1024))").astype(self.datatype) else: if self._baseColumnComputeMethod == VALUES_COMPUTE_METHOD: self.executionHistory.append(".. using values compute expression for seed") @@ -1103,8 +1193,9 @@ def _makeSingleGenerationExpression(self, index=None, use_pandas_optimizations=T else: self.logger.info("Assuming a seeded base expression with minimum value for column %s", self.name) self.executionHistory.append(f".. 
seeding with minimum `{self._dataRange.minValue}`") - new_def = ((self._getSeedExpression(self.baseColumn) + lit(self._dataRange.minValue)) - .astype(self.datatype)) + new_def = (self._getSeedExpression(self.baseColumn) + lit(self._dataRange.minValue)).astype( + self.datatype + ) if self.values is not None: new_def = F.element_at(F.array([F.lit(x) for x in self.values]), new_def.astype(IntegerType()) + 1) @@ -1154,8 +1245,9 @@ def _applyPrefixSuffixExpressions(self, cprefix, csuffix, new_def): text_separator = self.text_separator if self.text_separator is not None else '_' if cprefix is not None and csuffix is not None: self.executionHistory.append(".. applying column prefix and suffix") - new_def = concat(lit(cprefix), lit(text_separator), new_def.astype(IntegerType()), lit(text_separator), - lit(csuffix)) + new_def = concat( + lit(cprefix), lit(text_separator), new_def.astype(IntegerType()), lit(text_separator), lit(csuffix) + ) elif cprefix is not None: self.executionHistory.append(".. applying column prefix") new_def = concat(lit(cprefix), lit(text_separator), new_def.astype(IntegerType())) @@ -1177,17 +1269,15 @@ def _applyTextGenerationExpression(self, new_def, use_pandas_optimizations): tg = self.textGenerator if use_pandas_optimizations: self.executionHistory.append(f".. text generation via pandas scalar udf `{tg}`") - u_value_from_generator = pandas_udf(tg.pandasGenerateText, - returnType=StringType()).asNondeterministic() + u_value_from_generator = pandas_udf(tg.pandasGenerateText, returnType=StringType()).asNondeterministic() else: self.executionHistory.append(f".. text generation via udf `{tg}`") - u_value_from_generator = udf(tg.classicGenerateText, - StringType()).asNondeterministic() + u_value_from_generator = udf(tg.classicGenerateText, StringType()).asNondeterministic() new_def = u_value_from_generator(new_def) return new_def def _applyFinalCastExpression(self, col_type, new_def): - """ Apply final cast expression for column data + """Apply final cast expression for column data :param col_type: final column type :param new_def: column definition being created @@ -1209,9 +1299,9 @@ def _applyFinalCastExpression(self, col_type, new_def): def _applyComputePercentNullsExpression(self, newDef, probabilityNulls): """Compute percentage nulls for column being generated - :param newDef: Column definition being created - :param probabilityNulls: Probability of nulls to be generated for particular column. Values can be 0.0 - 1.0 - :returns: new column definition with probability of nulls applied + :param newDef: Column definition being created + :param probabilityNulls: Probability of nulls to be generated for particular column. Values can be 0.0 - 1.0 + :returns: new column definition with probability of nulls applied """ assert self.nullable, f"Column `{self.name}` must be nullable for `percent_nulls` option" self.executionHistory.append(".. applying null generator - `when rnd > prob then value - else null`") @@ -1225,9 +1315,9 @@ def _applyComputePercentNullsExpression(self, newDef, probabilityNulls): return newDef def _computeImpliedRangeIfNeeded(self, col_type): - """ Compute implied range if necessary - :param col_type" Column type - :returns: nothing + """Compute implied range if necessary + :param col_type" Column type + :returns: nothing """ # check for implied ranges if self.values is not None: @@ -1237,7 +1327,7 @@ def _computeImpliedRangeIfNeeded(self, col_type): self.executionHistory.append(f".. 
using adjusted effective range: {self._dataRange}") def _getMultiColumnDetails(self, validate): - """ Determine min and max number of columns to generate along with `structType` for columns + """Determine min and max number of columns to generate along with `structType` for columns with multiple columns / features :param validate: If true, raises ValueError if there are bad option entries @@ -1268,16 +1358,14 @@ def _getMultiColumnDetails(self, validate): min_num_columns, max_num_columns = 1, 1 if validate and (min_num_columns != max_num_columns) and (struct_type != self._ARRAY_STRUCT_TYPE): - self.logger.warning( - f"Varying number of features / columns specified for non-array column [{self.name}]") - self.logger.warning( - f"Lower bound for number of features / columns ignored for [{self.name}]") + self.logger.warning(f"Varying number of features / columns specified for non-array column [{self.name}]") + self.logger.warning(f"Lower bound for number of features / columns ignored for [{self.name}]") min_num_columns = max_num_columns return min_num_columns, max_num_columns, struct_type def makeGenerationExpressions(self): - """ Generate structured column if multiple columns or features are specified + """Generate structured column if multiple columns or features are specified if there are multiple columns / features specified using a single definition, it will generate a set of columns conforming to the same definition, diff --git a/dbldatagen/column_spec_options.py b/dbldatagen/column_spec_options.py index dcfb277b..ce950733 100644 --- a/dbldatagen/column_spec_options.py +++ b/dbldatagen/column_spec_options.py @@ -154,34 +154,54 @@ class ColumnSpecOptions(object): 'random_seed_method': 'randomSeedMethod', 'random_seed': 'randomSeed', 'text_separator': 'textSeparator', - } #: the set of attributes that are permitted for any call to data generator `withColumn` or `withColumnSpec` - _ALLOWED_PROPERTIES = {'name', 'type', 'minValue', 'maxValue', 'step', - 'prefix', 'random', 'distribution', - 'range', 'baseColumn', 'baseColumnType', 'values', - 'numColumns', 'numFeatures', 'structType', - 'begin', 'end', 'interval', 'expr', 'omit', - 'weights', 'description', 'continuous', - 'percentNulls', 'template', 'format', - 'uniqueValues', 'dataRange', 'text', - 'precision', 'scale', - 'randomSeedMethod', 'randomSeed', - 'nullable', 'implicit', 'escapeSpecialChars', - 'suffix', 'textSeparator' - } + _ALLOWED_PROPERTIES = { + 'name', + 'type', + 'minValue', + 'maxValue', + 'step', + 'prefix', + 'random', + 'distribution', + 'range', + 'baseColumn', + 'baseColumnType', + 'values', + 'numColumns', + 'numFeatures', + 'structType', + 'begin', + 'end', + 'interval', + 'expr', + 'omit', + 'weights', + 'description', + 'continuous', + 'percentNulls', + 'template', + 'format', + 'uniqueValues', + 'dataRange', + 'text', + 'precision', + 'scale', + 'randomSeedMethod', + 'randomSeed', + 'nullable', + 'implicit', + 'escapeSpecialChars', + 'suffix', + 'textSeparator', + } #: the set of disallowed column attributes for any call to data generator `withColumn` or `withColumnSpec` - _FORBIDDEN_PROPERTIES = { - 'range' - } + _FORBIDDEN_PROPERTIES = {'range'} #: maxValue values for each column type, only if where value is intentionally restricted - _MAX_TYPE_RANGE = { - 'byte': 256, - 'short': 65536, - 'int': 4294967296 - } + _MAX_TYPE_RANGE = {'byte': 256, 'short': 65536, 'int': 4294967296} def __init__(self, props, aliases=None): # TODO: check if additional options are needed here as `**kwArgs` self._options = props @@ 
-203,15 +223,15 @@ def __init__(self, props, aliases=None): # TODO: check if additional options ar @property def options(self): - """ Get options dictionary for object + """Get options dictionary for object - :return: options dictionary for object + :return: options dictionary for object """ return self._options def getOrElse(self, key, default=None): - """ Get val for key if it exists or else return default""" + """Get val for key if it exists or else return default""" assert key is not None, "key must be valid key string" if key in self._options: @@ -221,12 +241,12 @@ def getOrElse(self, key, default=None): return default def __getitem__(self, key): - """ implement the built in dereference by key behavior """ + """implement the built in dereference by key behavior""" ensure(key is not None, "key should be non-empty") return self._options.get(key, None) def checkBoolOption(self, v, name=None, optional=True): - """ Check that option is either not specified or of type boolean + """Check that option is either not specified or of type boolean :param v: value to test :param name: name of value to use in any reported errors or exceptions @@ -235,11 +255,12 @@ def checkBoolOption(self, v, name=None, optional=True): """ assert name is not None, "`name` must be specified" if optional: - ensure(v is None or type(v) is bool, - f"Option `{name}` must be boolean if specified - value: {v}, type: {type(v)}") + ensure( + v is None or type(v) is bool, + f"Option `{name}` must be boolean if specified - value: {v}, type: {type(v)}", + ) else: - ensure(type(v) is bool, - f"Option `{name}` must be boolean - value: {v}, type: {type(v)}") + ensure(type(v) is bool, f"Option `{name}` must be boolean - value: {v}, type: {type(v)}") def checkExclusiveOptions(self, options): """check if the options are exclusive - i.e only one is not None @@ -248,8 +269,9 @@ def checkExclusiveOptions(self, options): """ assert options is not None, "options must be non empty" assert type(options) is list, "`options` must be list" - assert len([self[x] for x in options if self[x] is not None]) <= 1, \ - f" only one of of the options: {options} may be specified " + assert ( + len([self[x] for x in options if self[x] is not None]) <= 1 + ), f" only one of of the options: {options} may be specified " def checkOptionValues(self, option, option_values): """check if option value is in list of values @@ -263,10 +285,10 @@ def checkOptionValues(self, option, option_values): def checkValidColumnProperties(self, columnProps): """ - check that column definition properties are recognized - and that the column definition has required properties + check that column definition properties are recognized + and that the column definition has required properties - :param columnProps: + :param columnProps: """ ensure(columnProps is not None, "columnProps should be non-empty") @@ -281,8 +303,10 @@ def checkValidColumnProperties(self, columnProps): raise ValueError("Effective range greater than range of type") for k in columnProps.keys(): - ensure(k in ColumnSpecOptions._ALLOWED_PROPERTIES or k in ColumnSpecOptions._PROPERTY_ALIASES, - f"invalid column option {k}") + ensure( + k in ColumnSpecOptions._ALLOWED_PROPERTIES or k in ColumnSpecOptions._PROPERTY_ALIASES, + f"invalid column option {k}", + ) for arg in self._REQUIRED_PROPERTIES: ensure(columnProps.get(arg) is not None, f"missing column option {arg}") @@ -292,9 +316,15 @@ def checkValidColumnProperties(self, columnProps): # check weights and values if 'weights' in columnProps: - ensure('values' in 
columnProps, - f"weights are only allowed for columns with values - column '{columnProps['name']}' ") - ensure(columnProps['values'] is not None and len(columnProps['values']) > 0, - f"weights must be associated with non-empty list of values - column '{columnProps['name']}' ") - ensure(len(columnProps['values']) == len(columnProps['weights']), - f"length(list of weights) must equal length(list of values) - column '{columnProps['name']}' ") + ensure( + 'values' in columnProps, + f"weights are only allowed for columns with values - column '{columnProps['name']}' ", + ) + ensure( + columnProps['values'] is not None and len(columnProps['values']) > 0, + f"weights must be associated with non-empty list of values - column '{columnProps['name']}' ", + ) + ensure( + len(columnProps['values']) == len(columnProps['weights']), + f"length(list of weights) must equal length(list of values) - column '{columnProps['name']}' ", + ) diff --git a/dbldatagen/config.py b/dbldatagen/config.py index a5fa7ecb..b1b835f6 100644 --- a/dbldatagen/config.py +++ b/dbldatagen/config.py @@ -20,6 +20,7 @@ class OutputDataset: :param format: Output data format (default is ``"delta"``). :param options: Optional dictionary of options for writing data (e.g. ``{"mergeSchema": "true"}``) """ + location: str output_mode: str = "append" format: str = "delta" diff --git a/dbldatagen/constraints/__init__.py b/dbldatagen/constraints/__init__.py index 5eb29438..4b1439b1 100644 --- a/dbldatagen/constraints/__init__.py +++ b/dbldatagen/constraints/__init__.py @@ -30,11 +30,14 @@ from .sql_expr import SqlExpr from .unique_combinations import UniqueCombinations -__all__ = ["chained_relation", - "constraint", - "negative_values", - "literal_range_constraint", - "literal_relation_constraint", - "positive_values", - "ranged_values_constraint", - "unique_combinations"] + +__all__ = [ + "chained_relation", + "constraint", + "literal_range_constraint", + "literal_relation_constraint", + "negative_values", + "positive_values", + "ranged_values_constraint", + "unique_combinations", +] diff --git a/dbldatagen/constraints/chained_relation.py b/dbldatagen/constraints/chained_relation.py index 568ff6ea..b5f66f35 100644 --- a/dbldatagen/constraints/chained_relation.py +++ b/dbldatagen/constraints/chained_relation.py @@ -3,32 +3,33 @@ # """ -This module defines the ChainedInequality class +This module defines the ChainedRelation class """ -from pyspark.sql import DataFrame +from typing import Any + import pyspark.sql.functions as F -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from pyspark.sql import Column +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict -class ChainedRelation(NoPrepareTransformMixin, Constraint): - """ChainedRelation constraint - Constrains one or more columns so that each column has a relationship to the next. +class ChainedRelation(NoPrepareTransformMixin, Constraint): + """ChainedRelation Constraint object - constrains one or more columns so that each column has a relationship to the next. - For example if the constraint is defined as `ChainedRelation(['a', 'b','c'], "<")` then only rows that - satisfy the condition `a < b < c` will be included in the output - (where `a`, `b` and `c` represent the data values for the rows). + Constrains one or more columns so that each column has a relationship to the next. 
For example, if the constraint is defined as `ChainedRelation(['a', 'b','c'], "<")` then only rows that + satisfy the condition `a < b < c` will be included in the output (where `a`, `b` and `c` represent the data values for the columns). This can be used to model time related transactions (for example in retail where the purchaseDate, shippingDate and returnDate all have a specific relationship) etc. - Relations supported include <, <=, >=, >, !=, == + Relations supported include "<", "<=", ">", ">=", "!=", "==". - :param columns: column name or list of column names as string or list of strings - :param relation: operator to check - should be one of <,> , =,>=,<=, ==, != + :param columns: Column name or list of column names as string or list of strings + :param relation: Operator to check - should be one of "<", "<=", ">", ">=", "!=", "==" """ - def __init__(self, columns, relation): + + def __init__(self, columns: str | list[str], relation: str) -> None: super().__init__(supportsStreaming=True) self._relation = relation self._columns = self._columnsFromListOrString(columns) @@ -39,20 +40,21 @@ def __init__(self, columns, relation): if not isinstance(self._columns, list) or len(self._columns) <= 1: raise ValueError("ChainedRelation constraints must be defined across more than one column") - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, Any]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Python dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "relation": self._relation, "columns": self._columns} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): - """ Generated composite filter expression for chained set of filter expressions + def _generateFilterExpression(self) -> Column | None: + """Generated composite filter expression for chained set of filter expressions I.e if columns is ['a', 'b', 'c'] and relation is '<' diff --git a/dbldatagen/constraints/constraint.py b/dbldatagen/constraints/constraint.py index 4e8b00e6..e3cb8781 100644 --- a/dbldatagen/constraints/constraint.py +++ b/dbldatagen/constraints/constraint.py @@ -5,49 +5,61 @@ """ This module defines the Constraint class """ -import types +from __future__ import annotations + from abc import ABC, abstractmethod -from pyspark.sql import Column -from ..serialization import SerializableToDict +from types import GeneratorType +from typing import TYPE_CHECKING, ClassVar + +from pyspark.sql import Column, DataFrame + +from dbldatagen.serialization import SerializableToDict + + +if TYPE_CHECKING: + # Imported only for type checking to avoid circular dependency at runtime + from dbldatagen.data_generator import DataGenerator class Constraint(SerializableToDict, ABC): - """ Constraint object - base class for predefined and custom constraints + """Constraint object - base class for predefined and custom constraints. This class is meant for internal use only. 
- """ - SUPPORTED_OPERATORS = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"] - def __init__(self, supportsStreaming=False): - """ - Initialize the constraint object - """ - self._filterExpression = None - self._calculatedFilterExpression = False + SUPPORTED_OPERATORS: ClassVar[list[str]] = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"] + _filterExpression: Column | None = None + _calculatedFilterExpression: bool = False + _supportsStreaming: bool = False + + def __init__(self, supportsStreaming: bool = False) -> None: self._supportsStreaming = supportsStreaming @staticmethod - def _columnsFromListOrString(columns): - """ Get columns as list of columns from string of list-like + def _columnsFromListOrString( + columns: str | list[str] | set[str] | tuple[str] | GeneratorType | None, + ) -> list[str]: + """Get a list of columns from string or list-like object - :param columns: string or list of strings representing column names + :param columns: String or list-like object representing column names + :return: List of column names """ + if columns is None: + raise ValueError("Columns must be a non-empty string or list-like of column names") + if isinstance(columns, str): return [columns] - elif isinstance(columns, (list, set, tuple, types.GeneratorType)): - return list(columns) - else: - raise ValueError("Columns must be a string or list of strings") + + return list(columns) @staticmethod - def _generate_relation_expression(column, relation, valueExpression): - """ Generate comparison expression + def _generate_relation_expression(column: Column, relation: str, valueExpression: Column) -> Column: + """Generate comparison expression :param column: Column to generate comparison against - :param relation: relation to implement - :param valueExpression: expression to compare to - :return: relation expression as variation of Pyspark SQL columns + :param relation: Relation to implement + :param valueExpression: Expression to compare to + :return: Relation expression as variation of Pyspark SQL columns """ if relation == ">": return column > valueExpression @@ -65,18 +77,17 @@ def _generate_relation_expression(column, relation, valueExpression): raise ValueError(f"Unsupported relation type '{relation}") @staticmethod - def mkCombinedConstraintExpression(constraintExpressions): - """ Generate a SQL expression that combines multiple constraints using AND - - :param constraintExpressions: list of Pyspark SQL Column constraint expression objects - :return: combined constraint expression as Pyspark SQL Column object (or None if no valid expressions) + def mkCombinedConstraintExpression(constraintExpressions: list[Column] | None) -> Column | None: + """Generate a SQL expression that combines multiple constraints using AND. 
+ :param constraintExpressions: List of Pyspark SQL Column constraint expression objects + :return: Combined constraint expression as Pyspark SQL Column object (or None if no valid expressions) """ - assert constraintExpressions is not None and isinstance(constraintExpressions, list), \ - "Constraints must be a list of Pyspark SQL Column instances" + if constraintExpressions is None or not isinstance(constraintExpressions, list): + raise ValueError("Constraints must be a list of Pyspark SQL Column instances") - assert all(expr is None or isinstance(expr, Column) for expr in constraintExpressions), \ - "Constraint expressions must be Pyspark SQL columns or None" + if not all(expr is None or isinstance(expr, Column) for expr in constraintExpressions): + raise ValueError("Constraint expressions must be Pyspark SQL columns or None") valid_constraint_expressions = [expr for expr in constraintExpressions if expr is not None] @@ -87,56 +98,64 @@ def mkCombinedConstraintExpression(constraintExpressions): combined_constraint_expression = combined_constraint_expression & additional_constraint return combined_constraint_expression - else: - return None - @abstractmethod - def prepareDataGenerator(self, dataGenerator): - """ Prepare the data generator to generate data that matches the constraint + return None - This method may modify the data generation rules to meet the constraint + @abstractmethod + def prepareDataGenerator(self, dataGenerator: DataGenerator) -> DataGenerator: + """Prepare the data generator to generate data that matches the constraint. This method may modify the data + generation rules to meet the constraint. - :param dataGenerator: Data generation object that will generate the dataframe - :return: modified or unmodified data generator + :param dataGenerator: Data generation object that will generate the dataframe + :return: Modified or unmodified data generator """ raise NotImplementedError("Method prepareDataGenerator must be implemented in derived class") @abstractmethod - def transformDataframe(self, dataGenerator, dataFrame): - """ Transform the dataframe to make data conform to constraint if possible - - This method should not modify the dataGenerator - but may modify the dataframe + def transformDataframe(self, dataGenerator: DataGenerator, dataFrame: DataFrame) -> DataFrame: + """Transform the dataframe to make data conform to constraint if possible - :param dataGenerator: Data generation object that generated the dataframe - :param dataFrame: generated dataframe - :return: modified or unmodified Spark dataframe - - The default transformation returns the dataframe unmodified + This method should not modify the dataGenerator - but may modify the dataframe. The default + transformation returns the dataframe unmodified + :param dataGenerator: Data generation object that generated the dataframe + :param dataFrame: Generated dataframe + :return: Modified or unmodified Spark dataframe """ raise NotImplementedError("Method transformDataframe must be implemented in derived class") @abstractmethod - def _generateFilterExpression(self): - """ Generate a Pyspark SQL expression that may be used for filtering""" + def _generateFilterExpression(self) -> Column | None: + """Generate a Pyspark SQL expression that may be used for filtering. 
+ + :return: Pyspark SQL expression that may be used for filtering + """ raise NotImplementedError("Method _generateFilterExpression must be implemented in derived class") @property - def supportsStreaming(self): - """ Return True if the constraint supports streaming dataframes""" + def supportsStreaming(self) -> bool: + """Return True if the constraint supports streaming dataframes. + + :return: True if the constraint supports streaming dataframes + """ return self._supportsStreaming @property - def filterExpression(self): - """ Return the filter expression (as instance of type Column that evaluates to True or non-True)""" - if not self._calculatedFilterExpression: - self._filterExpression = self._generateFilterExpression() - self._calculatedFilterExpression = True + def filterExpression(self) -> Column | None: + """Return the filter expression (as instance of type Column that evaluates to True or non-True). + + :return: Filter expression as Pyspark SQL Column object + """ + if self._calculatedFilterExpression: + return self._filterExpression + + self._calculatedFilterExpression = True + self._filterExpression = self._generateFilterExpression() return self._filterExpression class NoFilterMixin: - """ Mixin class to indicate that constraint has no filter expression + """Mixin class to indicate that constraint has no filter expression. Intended to be used in implementation of the concrete constraint classes. @@ -146,13 +165,17 @@ class NoFilterMixin: When using mixins, place the mixin class first in the list of base classes. """ - def _generateFilterExpression(self): - """ Generate a Pyspark SQL expression that may be used for filtering""" + + def _generateFilterExpression(self) -> None: + """Generate a Pyspark SQL expression that may be used for filtering. + + :return: Pyspark SQL expression that may be used for filtering + """ return None class NoPrepareTransformMixin: - """ Mixin class to indicate that constraint has no filter expression + """Mixin class to indicate that constraint has no filter expression Intended to be used in implementation of the concrete constraint classes. @@ -162,26 +185,25 @@ class NoPrepareTransformMixin: When using mixins, place the mixin class first in the list of base classes. 
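[Editor's note: an illustrative sketch, not part of this patch, showing how a concrete constraint could be written against the typed base class and mixins above. The class name NonNullValues is hypothetical; the pattern follows the guidance in the diff: mixin first in the base-class list, `_columnsFromListOrString` for argument handling, and `mkCombinedConstraintExpression` to AND the per-column filters.]

import pyspark.sql.functions as F
from pyspark.sql import Column

from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin


class NonNullValues(NoPrepareTransformMixin, Constraint):
    """Hypothetical constraint requiring the named columns to be non-null."""

    def __init__(self, columns: str | list[str]) -> None:
        super().__init__(supportsStreaming=True)
        self._columns = self._columnsFromListOrString(columns)

    def _generateFilterExpression(self) -> Column | None:
        # One IS NOT NULL predicate per column, combined with AND by the base-class helper
        filters = [F.col(name).isNotNull() for name in self._columns]
        return self.mkCombinedConstraintExpression(filters)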
""" - def prepareDataGenerator(self, dataGenerator): - """ Prepare the data generator to generate data that matches the constraint - This method may modify the data generation rules to meet the constraint + def prepareDataGenerator(self, dataGenerator: DataGenerator) -> DataGenerator: + """Prepare the data generator to generate data that matches the constraint + + This method may modify the data generation rules to meet the constraint - :param dataGenerator: Data generation object that will generate the dataframe - :return: modified or unmodified data generator + :param dataGenerator: Data generation object that will generate the dataframe + :return: Modified or unmodified data generator """ return dataGenerator - def transformDataframe(self, dataGenerator, dataFrame): - """ Transform the dataframe to make data conform to constraint if possible - - This method should not modify the dataGenerator - but may modify the dataframe - - :param dataGenerator: Data generation object that generated the dataframe - :param dataFrame: generated dataframe - :return: modified or unmodified Spark dataframe + def transformDataframe(self, dataGenerator: DataGenerator, dataFrame: DataFrame) -> DataFrame: + """Transform the dataframe to make data conform to constraint if possible - The default transformation returns the dataframe unmodified + This method should not modify the dataGenerator - but may modify the dataframe. The default + transformation returns the dataframe unmodified + :param dataGenerator: Data generation object that generated the dataframe + :param dataFrame: Generated dataframe + :return: Modified or unmodified Spark dataframe """ return dataFrame diff --git a/dbldatagen/constraints/literal_range_constraint.py b/dbldatagen/constraints/literal_range_constraint.py index f554a472..b1690c3d 100644 --- a/dbldatagen/constraints/literal_range_constraint.py +++ b/dbldatagen/constraints/literal_range_constraint.py @@ -3,18 +3,19 @@ # """ -This module defines the ScalarRange class +This module defines the LiteralRange class """ import pyspark.sql.functions as F +from pyspark.sql import Column -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict class LiteralRange(NoPrepareTransformMixin, Constraint): - """ LiteralRange Constraint object - validates that column value(s) are between 2 literal values + """LiteralRange Constraint object - validates that column value(s) are between 2 literal values. - :param columns: Name of column or list of column names + :param columns: Column name or list of column names as string or list of strings :param lowValue: Tests that columns have values greater than low value (greater or equal if `strict` is False) :param highValue: Tests that columns have values less than high value (less or equal if `strict` is False) :param strict: If True, excludes low and high values from range. 
Defaults to False @@ -23,33 +24,37 @@ class LiteralRange(NoPrepareTransformMixin, Constraint): `pyspark.sql.functions.lit` function """ - def __init__(self, columns, lowValue, highValue, strict=False): + def __init__(self, columns: str | list[str], lowValue: object, highValue: object, strict: bool = False) -> None: super().__init__(supportsStreaming=True) self._columns = self._columnsFromListOrString(columns) self._lowValue = lowValue self._highValue = highValue self._strict = strict - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = { "kind": self.__class__.__name__, "columns": self._columns, "lowValue": self._lowValue, "highValue": self._highValue, - "strict": self._strict + "strict": self._strict, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): - """ Generate a SQL filter expression that may be used for filtering""" + def _generateFilterExpression(self) -> Column | None: + """Generate a SQL filter expression that may be used for filtering. + + :return: SQL filter expression as Pyspark SQL Column object + """ expressions = [F.col(colname) for colname in self._columns] minValue = F.lit(self._lowValue) maxValue = F.lit(self._highValue) diff --git a/dbldatagen/constraints/literal_relation_constraint.py b/dbldatagen/constraints/literal_relation_constraint.py index 8017047f..6c0ea7b7 100644 --- a/dbldatagen/constraints/literal_relation_constraint.py +++ b/dbldatagen/constraints/literal_relation_constraint.py @@ -3,25 +3,24 @@ # """ -This module defines the ScalarInequality class +This module defines the LiteralRelation class """ import pyspark.sql.functions as F +from pyspark.sql import Column -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict class LiteralRelation(NoPrepareTransformMixin, Constraint): - """LiteralRelation constraint + """Literal Relation Constraint object - constrains one or more columns so that the columns have an a relationship to a constant value. 
- Constrains one or more columns so that the columns have an a relationship to a constant value - - :param columns: column name or list of column names - :param relation: operator to check - should be one of <,> , =,>=,<=, ==, != - :param value: A literal value to to compare against + :param columns: Column name or list of column names as string or list of strings + :param relation: Operator to check - should be one of <,> , =,>=,<=, ==, != + :param value: Literal value to to compare against """ - def __init__(self, columns, relation, value): + def __init__(self, columns: str | list[str], relation: str, value: object) -> None: super().__init__(supportsStreaming=True) self._columns = self._columnsFromListOrString(columns) self._relation = relation @@ -30,24 +29,25 @@ def __init__(self, columns, relation, value): if relation not in self.SUPPORTED_OPERATORS: raise ValueError(f"Parameter `relation` should be one of the operators :{self.SUPPORTED_OPERATORS}") - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = { "kind": self.__class__.__name__, "columns": self._columns, "relation": self._relation, - "value": self._value + "value": self._value, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): + def _generateFilterExpression(self) -> Column | None: expressions = [F.col(colname) for colname in self._columns] literalValue = F.lit(self._value) filters = [self._generate_relation_expression(col, self._relation, literalValue) for col in expressions] diff --git a/dbldatagen/constraints/negative_values.py b/dbldatagen/constraints/negative_values.py index 2e10f6d8..f2ccf4e9 100644 --- a/dbldatagen/constraints/negative_values.py +++ b/dbldatagen/constraints/negative_values.py @@ -3,44 +3,48 @@ # """ -This module defines the Negative class +This module defines the NegativeValues class """ import pyspark.sql.functions as F -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from pyspark.sql import Column +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict -class NegativeValues(NoPrepareTransformMixin, Constraint): - """ Negative Value constraints - - Applies constraint to ensure columns have negative values - :param columns: string column name or list of column names as strings - :param strict: if strict is True, the zero value is not considered negative +class NegativeValues(NoPrepareTransformMixin, Constraint): + """Negative Value constraints. - Essentially applies the constraint that the named columns must be less than equal zero - or less than zero if strict has the value `True` + Applies constraint to ensure columns have negative values. Constrains values in named + columns to be less than equal zero or less than zero if strict has the value `True`. 
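[Editor's note: a minimal usage sketch, not part of this patch, attaching the literal constraints above to a generation spec. The Spark session `spark`, the spec name, the column name, and the package-level re-exports (`dg.LiteralRange`, `dg.LiteralRelation`) are assumptions for illustration.]

import dbldatagen as dg

spec = (
    dg.DataGenerator(spark, name="constrained_data", rows=100_000)
    .withColumn("reading", "int", minValue=-50, maxValue=150, random=True)
    .withConstraints([
        dg.LiteralRange("reading", lowValue=0, highValue=100),   # keep 0 <= reading <= 100
        dg.LiteralRelation("reading", "!=", 42),                 # and drop the literal value 42
    ])
)
df = spec.build()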
+ :param columns: Column name or list of column names as string or list of strings + :param strict: If True, the zero value is not considered negative """ - def __init__(self, columns, strict=False): + def __init__(self, columns: str | list[str], strict: bool = False) -> None: super().__init__(supportsStreaming=True) self._columns = self._columnsFromListOrString(columns) self._strict = strict - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "columns": self._columns, "strict": self._strict} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): + def _generateFilterExpression(self) -> Column | None: + """Generate a SQL filter expression that may be used for filtering. + + :return: SQL filter expression as Pyspark SQL Column object + """ expressions = [F.col(colname) for colname in self._columns] if self._strict: filters = [col.isNotNull() & (col < 0) for col in expressions] diff --git a/dbldatagen/constraints/positive_values.py b/dbldatagen/constraints/positive_values.py index 5a8a3267..4d1b7eb5 100644 --- a/dbldatagen/constraints/positive_values.py +++ b/dbldatagen/constraints/positive_values.py @@ -3,45 +3,48 @@ # """ -This module defines the Positive class +This module defines the PositiveValues class """ import pyspark.sql.functions as F -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from pyspark.sql import Column +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict -class PositiveValues(NoPrepareTransformMixin, Constraint): - """ Positive Value constraints - - Applies constraint to ensure columns have positive values - :param columns: string column name or list of column names as strings - :param strict: if strict is True, the zero value is not considered positive +class PositiveValues(NoPrepareTransformMixin, Constraint): + """Positive Value constraints. - Essentially applies the constraint that the named columns must be greater than equal zero - or greater than zero if strict has the value `True` + Applies constraint to ensure columns have positive values. Constrains values in named + columns to be greater than equal zero or greater than zero if strict has the value `True`. + :param columns: Column name or list of column names as string or list of strings + :param strict: If True, the zero value is not considered positive """ - def __init__(self, columns, strict=False): + def __init__(self, columns: str | list[str], strict: bool = False) -> None: super().__init__(supportsStreaming=True) self._columns = self._columnsFromListOrString(columns) self._strict = strict - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. 
- :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "columns": self._columns, "strict": self._strict} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): - """ Generate a filter expression that may be used for filtering""" + def _generateFilterExpression(self) -> Column | None: + """Generate a SQL filter expression that may be used for filtering. + + :return: SQL filter expression as Pyspark SQL Column object + """ expressions = [F.col(colname) for colname in self._columns] if self._strict: filters = [col.isNotNull() & (col > 0) for col in expressions] diff --git a/dbldatagen/constraints/ranged_values_constraint.py b/dbldatagen/constraints/ranged_values_constraint.py index a727bb7b..81a80a9d 100644 --- a/dbldatagen/constraints/ranged_values_constraint.py +++ b/dbldatagen/constraints/ranged_values_constraint.py @@ -3,52 +3,57 @@ # """ -This module defines the ScalarRange class +This module defines the RangedValues class """ import pyspark.sql.functions as F +from pyspark.sql import Column -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict class RangedValues(NoPrepareTransformMixin, Constraint): - """ RangedValues Constraint object - validates that column value(s) are between 2 column values + """RangedValues Constraint object - validates that column values are in the range defined by values in + `lowValue` and `highValue` columns. `lowValue` and `highValue` must be names of columns that contain + the low and high values respectively. :param columns: Name of column or list of column names - :param lowValue: Tests that columns have values greater than low value (greater or equal if `strict` is False) - :param highValue: Tests that columns have values less than high value (less or equal if `strict` is False) + :param lowValue: Name of column containing the low value + :param highValue: Name of column containing the high value :param strict: If True, excludes low and high values from range. Defaults to False - - Note `lowValue` and `highValue` must be names of columns that contain the low and high values """ - def __init__(self, columns, lowValue, highValue, strict=False): + def __init__(self, columns: str | list[str], lowValue: str, highValue: str, strict: bool = False) -> None: super().__init__(supportsStreaming=True) self._columns = self._columnsFromListOrString(columns) self._lowValue = lowValue self._highValue = highValue self._strict = strict - def _toInitializationDict(self): - """ Returns an internal mapping dictionary for the object. Keys represent the - class constructor arguments and values representing the object's internal data. - :return: Python dictionary mapping constructor options to the object properties + def _toInitializationDict(self) -> dict[str, object]: + """Returns an internal mapping dictionary for the object. 
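[Editor's note: a hedged sketch, not part of this patch, contrasting RangedValues with LiteralRange: here the bounds come from other columns rather than literals. The Spark session, spec name, and column names are assumptions.]

import dbldatagen as dg

spec = (
    dg.DataGenerator(spark, name="ranged_example", rows=1000)
    .withColumn("low_bound", "int", minValue=0, maxValue=10)
    .withColumn("high_bound", "int", minValue=90, maxValue=100)
    .withColumn("value", "int", minValue=0, maxValue=100, random=True)
    .withConstraint(dg.RangedValues("value", lowValue="low_bound", highValue="high_bound"))
)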
Keys represent the + class constructor arguments and values representing the object's internal data. + + :return: Dictionary mapping constructor options to the object properties """ _options = { "kind": self.__class__.__name__, "columns": self._columns, "lowValue": self._lowValue, "highValue": self._highValue, - "strict": self._strict + "strict": self._strict, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): - """ Generate a SQL filter expression that may be used for filtering""" + def _generateFilterExpression(self) -> Column | None: + """Generate a SQL filter expression that may be used for filtering. + + :return: SQL filter expression as Pyspark SQL Column object + """ expressions = [F.col(colname) for colname in self._columns] minValue = F.col(self._lowValue) maxValue = F.col(self._highValue) diff --git a/dbldatagen/constraints/sql_expr.py b/dbldatagen/constraints/sql_expr.py index cff66fbe..f5bdf18d 100644 --- a/dbldatagen/constraints/sql_expr.py +++ b/dbldatagen/constraints/sql_expr.py @@ -6,38 +6,40 @@ This module defines the SqlExpr class """ import pyspark.sql.functions as F +from pyspark.sql import Column -from .constraint import Constraint, NoPrepareTransformMixin -from ..serialization import SerializableToDict +from dbldatagen.constraints.constraint import Constraint, NoPrepareTransformMixin +from dbldatagen.serialization import SerializableToDict class SqlExpr(NoPrepareTransformMixin, Constraint): - """ SQL Expression Constraint object - - This class represents a constraint that is modelled using a SQL expression + """SQL Expression Constraint object - represents a constraint that is modelled using a SQL expression. :param expr: A SQL expression as a string - """ - def __init__(self, expr: str): + def __init__(self, expr: str) -> None: super().__init__(supportsStreaming=True) - assert expr is not None, "Expression must be a valid SQL string" - assert isinstance(expr, str) and len(expr.strip()) > 0, "Expression must be a valid SQL string" + if not expr: + raise ValueError("Expression must be a valid non-empty SQL string") self._expr = expr - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "expr": self._expr} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def _generateFilterExpression(self): - """ Generate a SQL filter expression that may be used for filtering""" + def _generateFilterExpression(self) -> Column: + """Generate a SQL filter expression that may be used for filtering. 
+ + :return: SQL filter expression as Pyspark SQL Column object + """ return F.expr(self._expr) diff --git a/dbldatagen/constraints/unique_combinations.py b/dbldatagen/constraints/unique_combinations.py index c6505bd3..74f5a728 100644 --- a/dbldatagen/constraints/unique_combinations.py +++ b/dbldatagen/constraints/unique_combinations.py @@ -3,30 +3,37 @@ # """ -This module defines the Positive class +This module defines the UniqueCombinations class """ -from .constraint import Constraint, NoFilterMixin -from ..serialization import SerializableToDict +from __future__ import annotations +from typing import TYPE_CHECKING -class UniqueCombinations(NoFilterMixin, Constraint): - """ Unique Combinations constraints +from pyspark.sql import DataFrame - Applies constraint to ensure columns have unique combinations - i.e the set of columns supplied only have - one combination of each set of values +from dbldatagen.constraints.constraint import Constraint, NoFilterMixin +from dbldatagen.serialization import SerializableToDict - :param columns: string column name or list of column names as strings.If no columns are specified, all output - columns will be considered when dropping duplicate combinations. - if the columns are not specified, or the column name of '*' is used, - all columns that would be present in the final output are considered. +if TYPE_CHECKING: + # Imported only for type checking to avoid circular dependency at runtime + from dbldatagen.data_generator import DataGenerator + + +class UniqueCombinations(NoFilterMixin, Constraint): + """Unique Combinations Constraint object - ensures that columns have unique combinations of values - i.e + the set of columns supplied only have one combination of each set of values. + + If the columns are not specified, or the column name of '*' is used, all columns that would be present + in the final output are considered. - Essentially applies the constraint that the named columns have unique values for each combination of columns. + The uniqueness constraint may apply to columns that are omitted - i.e not part of the final output. If no + column or column list is supplied, all columns that would be present in the final output are considered. - The uniqueness constraint may apply to columns that are omitted - i.e not part of the final output. - If no column or column list is supplied, all columns that would be present in the final output are considered. + This is useful to enforce unique ids, unique keys, etc. - This is useful to enforce unique ids, unique keys etc. + :param columns: String column name or list of column names as strings. If no columns are specified, all output + columns will be considered when dropping duplicate combinations. ..Note: When applied to streaming dataframe, it will perform any deduplication only within a batch. @@ -39,48 +46,49 @@ class UniqueCombinations(NoFilterMixin, Constraint): """ - def __init__(self, columns=None): + def __init__(self, columns: str | list[str] | None = None) -> None: super().__init__(supportsStreaming=False) - if columns is not None and columns != "*": + if columns and columns != "*": self._columns = self._columnsFromListOrString(columns) else: - self._columns = None + self._columns = [] + + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. 
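[Editor's note: a hedged usage sketch, not part of this patch, combining the two constraint styles above. The spec, column names, and Spark session are assumptions; the comment on ordering reflects the `_applyPostGenerationConstraints` logic shown later in this patch (transforms first, then the combined filter).]

import dbldatagen as dg

spec = (
    dg.DataGenerator(spark, name="orders", rows=10_000)
    .withColumn("customer_id", "int", minValue=1, maxValue=1000, random=True)
    .withColumn("product_id", "int", minValue=1, maxValue=500, random=True)
    .withConstraint(dg.SqlExpr("customer_id <> product_id"))
    .withConstraint(dg.UniqueCombinations(["customer_id", "product_id"]))
)
df = spec.build()   # duplicates on (customer_id, product_id) are dropped, then the SQL filter is applied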
- :return: Python dictionary representation of the object + :return: Dictionary representation of the object """ - _options = {"kind": self.__class__.__name__, "columns": self._columns} + _options = {"kind": self.__class__.__name__, "columns": self._columns or []} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def prepareDataGenerator(self, dataGenerator): - """ Prepare the data generator to generate data that matches the constraint + def prepareDataGenerator(self, dataGenerator: DataGenerator) -> DataGenerator: + """Prepare the data generator to generate data that matches the constraint. - This method may modify the data generation rules to meet the constraint + This method may modify the data generation rules to meet the constraint. - :param dataGenerator: Data generation object that will generate the dataframe - :return: modified or unmodified data generator + :param dataGenerator: Data generation object that will generate the dataframe + :return: Modified or unmodified data generator """ return dataGenerator - def transformDataframe(self, dataGenerator, dataFrame): - """ Transform the dataframe to make data conform to constraint if possible + def transformDataframe(self, dataGenerator: DataGenerator, dataFrame: DataFrame) -> DataFrame: + """Transform the dataframe to make data conform to constraint if possible. - This method should not modify the dataGenerator - but may modify the dataframe + This method should not modify the dataGenerator - but may modify the dataframe. - :param dataGenerator: Data generation object that generated the dataframe - :param dataFrame: generated dataframe - :return: modified or unmodified Spark dataframe + :param dataGenerator: Data generation object that generated the dataframe + :param dataFrame: Generated dataframe + :return: Modified or unmodified Spark dataframe - The default transformation returns the dataframe unmodified + The default transformation returns the dataframe unmodified """ - if self._columns is None: + if not self._columns: # if no columns are specified, then all columns that would appear in the final output are used # when determining duplicates columnsToEvaluate = dataGenerator.getOutputColumnNames() @@ -88,7 +96,4 @@ def transformDataframe(self, dataGenerator, dataFrame): columnsToEvaluate = self._columns # for batch processing, duplicate rows will be removed via drop duplicates - - results = dataFrame.dropDuplicates(columnsToEvaluate) - - return results + return dataFrame.dropDuplicates(columnsToEvaluate) diff --git a/dbldatagen/data_analyzer.py b/dbldatagen/data_analyzer.py index 5dcaaf92..01a4760a 100644 --- a/dbldatagen/data_analyzer.py +++ b/dbldatagen/data_analyzer.py @@ -34,6 +34,7 @@ class DataAnalyzer: .. 
warning:: Experimental """ + debug: bool verbose: bool _sparkSession: SparkSession @@ -50,7 +51,7 @@ class DataAnalyzer: |# Github project - [https://github.com/databrickslabs/dbldatagen] |# """, - marginChar="|" + marginChar="|", ) _GENERATED_FROM_SCHEMA_COMMENT: str = strip_margins( @@ -58,7 +59,7 @@ class DataAnalyzer: |# Column definitions are stubs only - modify to generate correct data |# """, - marginChar="|" + marginChar="|", ) def __init__( @@ -66,7 +67,7 @@ def __init__( df: DataFrame | None = None, sparkSession: SparkSession | None = None, debug: bool = False, - verbose: bool = False + verbose: bool = False, ) -> None: self.verbose = verbose self.debug = debug @@ -114,7 +115,7 @@ def _addMeasureToSummary( fieldExprs: list[str] | None = None, dfData: DataFrame | None, rowLimit: int = 1, - dfSummary: DataFrame | None = None + dfSummary: DataFrame | None = None, ) -> DataFrame: """ Adds a new measure to the summary ``DataFrame``. @@ -192,7 +193,7 @@ def summarizeToDF(self) -> DataFrame: measureName="schema", summaryExpr=f"""to_json(named_struct('column_count', {len(dtypes)}))""", fieldExprs=[f"'{dtype[1]}' as {dtype[0]}" for dtype in dtypes], - dfData=self._df + dfData=self._df, ) data_summary_df = self._addMeasureToSummary( @@ -200,17 +201,17 @@ def summarizeToDF(self) -> DataFrame: summaryExpr=f"{total_count}", fieldExprs=[f"string(count({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) data_summary_df = self._addMeasureToSummary( measureName="null_probability", fieldExprs=[ f"""string( round( ({total_count} - count({dtype[0]})) /{total_count}, 2)) as {dtype[0]}""" - for dtype in dtypes + for dtype in dtypes ], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) # distinct count @@ -219,7 +220,7 @@ def summarizeToDF(self) -> DataFrame: summaryExpr="count(distinct *)", fieldExprs=[f"string(count(distinct {dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) # min @@ -227,20 +228,18 @@ def summarizeToDF(self) -> DataFrame: measureName="min", fieldExprs=[f"string(min({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) data_summary_df = self._addMeasureToSummary( measureName="max", fieldExprs=[f"string(max({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) - description_df = ( - self - ._get_dataframe_describe_stats(self._df) - .where(f"{DATA_SUMMARY_FIELD_NAME} in ('mean', 'stddev')") + description_df = self._get_dataframe_describe_stats(self._df).where( + f"{DATA_SUMMARY_FIELD_NAME} in ('mean', 'stddev')" ) description_data = description_df.collect() @@ -257,7 +256,7 @@ def summarizeToDF(self) -> DataFrame: measureName=measure, fieldExprs=[f"'{values[dtype[0]]}'" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) # string characteristics for strings and string representation of other values @@ -265,14 +264,14 @@ def summarizeToDF(self) -> DataFrame: measureName="print_len_min", fieldExprs=[f"string(min(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=data_summary_df + dfSummary=data_summary_df, ) data_summary_df = self._addMeasureToSummary( measureName="print_len_max", fieldExprs=[f"string(max(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes], dfData=self._df, 
- dfSummary=data_summary_df + dfSummary=data_summary_df, ) return data_summary_df @@ -287,10 +286,7 @@ def summarize(self, suppressOutput: bool = False) -> str: """ summary_df = self.summarizeToDF() - results = [ - "Data set summary", - "================" - ] + results = ["Data set summary", "================"] for row in summary_df.collect(): results.append(self._displayRow(row)) @@ -308,7 +304,7 @@ def _valueFromSummary( dataSummary: dict[str, dict[str, object]] | None = None, colName: str | None = None, measure: str | None = None, - defaultValue: int | float | str | None = None + defaultValue: int | float | str | None = None, ) -> object: """ Gets a measure value from a data summary given a measure name and column name. Returns a default value when the @@ -334,10 +330,7 @@ def _valueFromSummary( @classmethod def _generatorDefaultAttributesFromType( - cls, - sqlType: types.DataType, - colName: str | None = None, - dataSummary: dict | None = None + cls, sqlType: types.DataType, colName: str | None = None, dataSummary: dict | None = None ) -> str: """ Generates a Spark SQL expression for the input column and data type. Optionally uses ``DataAnalyzer`` summary @@ -421,7 +414,7 @@ def _scriptDataGeneratorCode( dataSummary: dict | None = None, sourceDf: DataFrame | None = None, suppressOutput: bool = False, - name: str | None = None + name: str | None = None, ) -> str: """ Generates code to build a ``DataGenerator`` from an existing dataframe. Analyzes the dataframe passed to the @@ -455,7 +448,7 @@ def _scriptDataGeneratorCode( | rows=100000, | random=True, | )""", - marginChar="|" + marginChar="|", ) ) diff --git a/dbldatagen/data_generator.py b/dbldatagen/data_generator.py index 14fa92e0..97e7b4fc 100644 --- a/dbldatagen/data_generator.py +++ b/dbldatagen/data_generator.py @@ -103,9 +103,9 @@ def __init__( debug: bool = False, seedColumnName: str = datagen_constants.DEFAULT_SEED_COLUMN, random: bool = False, - **kwargs + **kwargs, ) -> None: - """ Constructor for data generator object """ + """Constructor for data generator object""" # set up logging self.verbose = verbose @@ -187,7 +187,9 @@ def __init__( self._seedMethod = "fixed" allowed_seed_methods = [ - None, datagen_constants.RANDOM_SEED_FIXED, datagen_constants.RANDOM_SEED_HASH_FIELD_NAME + None, + datagen_constants.RANDOM_SEED_FIXED, + datagen_constants.RANDOM_SEED_HASH_FIELD_NAME, ] if self._seedMethod not in allowed_seed_methods: msg = f"seedMethod should be None, '{datagen_constants.RANDOM_SEED_FIXED}' or '{datagen_constants.RANDOM_SEED_HASH_FIELD_NAME}' " @@ -266,14 +268,16 @@ def _toInitializationDict(self) -> dict[str, Any]: "startingId": self.starting_id, "randomSeed": self._randomSeed, "partitions": self.partitions, - "verbose": self.verbose, "batchSize": self._batchSize, "debug": self.debug, + "verbose": self.verbose, + "batchSize": self._batchSize, + "debug": self.debug, "seedColumnName": self._seedColumnName, "random": self._defaultRandom, - "columns": [{ - k: v for k, v in column._toInitializationDict().items() - if k != "kind"} - for column in self.columnGenerationSpecs], - "constraints": [constraint._toInitializationDict() for constraint in self.constraints] + "columns": [ + {k: v for k, v in column._toInitializationDict().items() if k != "kind"} + for column in self.columnGenerationSpecs + ], + "constraints": [constraint._toInitializationDict() for constraint in self.constraints], } return _options @@ -309,8 +313,7 @@ def _checkSparkVersion(cls, sparkVersion: str, minSparkVersion: tuple[int, int, if 
spark_version_info < minSparkVersion: logging.warning( - f"*** Minimum version of Python supported is {minSparkVersion} - found version %s ", - spark_version_info + f"*** Minimum version of Python supported is {minSparkVersion} - found version %s ", spark_version_info ) return False @@ -484,7 +487,7 @@ def explain(self, suppressOutput: bool = False) -> str: "", f"column build order: {self._buildOrder}", "", - "build plan:" + "build plan:", ] for plan_action in self._buildPlan: @@ -602,7 +605,7 @@ def describe(self) -> dict[str, Any]: "partitions": self.partitions, "columnDefinitions": self._columnSpecsByName, "debug": self.debug, - "verbose": self.verbose + "verbose": self.verbose, } def __repr__(self) -> str: @@ -661,7 +664,7 @@ def inferredSchema(self) -> StructType: return StructType(self._inferredSchemaFields) def __getitem__(self, key: str) -> ColumnGenerationSpec: - """ implement the built-in dereference by key behavior """ + """implement the built-in dereference by key behavior""" ensure(key is not None, "key should be non-empty") return self._columnSpecsByName[key] @@ -724,10 +727,13 @@ def getOutputColumnNames(self) -> list[str]: :returns: List of column names in the `DataGenerator's` output `DataFrame` """ - return self.flatten([ - self._columnSpecsByName[fd.name].getNames() for fd in - self._inferredSchemaFields if not self._columnSpecsByName[fd.name].isFieldOmitted - ]) + return self.flatten( + [ + self._columnSpecsByName[fd.name].getNames() + for fd in self._inferredSchemaFields + if not self._columnSpecsByName[fd.name].isFieldOmitted + ] + ) def getOutputColumnNamesAndTypes(self) -> list[tuple[str, DataType]]: """ @@ -736,9 +742,13 @@ def getOutputColumnNamesAndTypes(self) -> list[tuple[str, DataType]]: :returns: A list of tuples of column name and data type """ - return self.flatten([self._columnSpecsByName[fd.name].getNamesAndTypes() - for fd in - self._inferredSchemaFields if not self._columnSpecsByName[fd.name].isFieldOmitted]) + return self.flatten( + [ + self._columnSpecsByName[fd.name].getNamesAndTypes() + for fd in self._inferredSchemaFields + if not self._columnSpecsByName[fd.name].isFieldOmitted + ] + ) def withSchema(self, sch: StructType) -> "DataGenerator": """ @@ -760,7 +770,7 @@ def _computeRange( dataRange: DataRange | range | None = None, minValue: int | float | complex | date | datetime | None = None, maxValue: int | float | complex | date | datetime | None = None, - step: int | float | complex | timedelta | None = None + step: int | float | complex | timedelta | None = None, ) -> tuple[Any, Any, Any]: """ Computes a numeric range from the input parameters. 
@@ -783,7 +793,7 @@ def withColumnSpecs( patterns: str | list[str] | None = None, fields: str | list[str] | None = None, matchTypes: str | list[str] | DataType | list[DataType] | None = None, - **kwargs + **kwargs, ) -> "DataGenerator": """ Adds column specs for columns matching: @@ -822,8 +832,7 @@ def withColumnSpecs( patterns = ["^" + pat + "$" for pat in patterns] all_fields = self.getInferredColumnNames() - effective_fields = [x for x in all_fields if - (fields is None or x in fields) and x != self._seedColumnName] + effective_fields = [x for x in all_fields if (fields is None or x in fields) and x != self._seedColumnName] if patterns: effective_fields = [x for x in effective_fields for y in patterns if re.search(y, x) is not None] @@ -841,9 +850,7 @@ def withColumnSpecs( else: effective_types.append(match_type) - effective_fields = [ - x for x in effective_fields for y in effective_types if self.getColumnType(x) == y - ] + effective_fields = [x for x in effective_fields for y in effective_types if self.getColumnType(x) == y] for f in effective_fields: self.withColumnSpec(f, implicit=True, **kwargs) @@ -866,7 +873,7 @@ def _checkColumnOrColumnList(self, columns: str | list[str], allowId: bool = Fal for column in columns: ensure(column in inferred_columns, f" column `{column}` must refer to defined column") else: - ensure(columns in inferred_columns,f" column `{columns}` must refer to defined column") + ensure(columns in inferred_columns, f" column `{columns}` must refer to defined column") return True def withColumnSpec( @@ -883,7 +890,7 @@ def withColumnSpec( dataRange: DataRange | None = None, omit: bool = False, baseColumn: str | None = None, - **kwargs + **kwargs, ) -> "DataGenerator": """ Adds a `ColumnGenerationSpec` for an existing column. @@ -905,7 +912,7 @@ def withColumnSpec( not isinstance(minValue, DataType), """unnecessary `datatype` argument specified for `withColumnSpec` for column `{colName}` - Datatype parameter is only needed for `withColumn` and not permitted for `withColumnSpec` - """ + """, ) if random is None: @@ -913,14 +920,16 @@ def withColumnSpec( # handle migration of old `min` and `max` options if _OLD_MIN_OPTION in kwargs: - assert minValue is None, \ - "Only one of `minValue` and `minValue` can be specified. Use of `minValue` is preferred" + assert ( + minValue is None + ), "Only one of `minValue` and `minValue` can be specified. Use of `minValue` is preferred" minValue = kwargs[_OLD_MIN_OPTION] kwargs.pop(_OLD_MIN_OPTION, None) if _OLD_MAX_OPTION in kwargs: - assert maxValue is None, \ - "Only one of `maxValue` and `maxValue` can be specified. Use of `maxValue` is preferred" + assert ( + maxValue is None + ), "Only one of `maxValue` and `maxValue` can be specified. 
Use of `maxValue` is preferred" maxValue = kwargs[_OLD_MAX_OPTION] kwargs.pop(_OLD_MAX_OPTION, None) @@ -936,14 +945,15 @@ def withColumnSpec( self.getColumnType(colName), minValue=minValue, maxValue=maxValue, - step=step, prefix=prefix, + step=step, + prefix=prefix, random=random, dataRange=dataRange, distribution=distribution, baseColumn=baseColumn, implicit=implicit, omit=omit, - **new_props + **new_props, ) return self @@ -964,7 +974,8 @@ def withColumn( minValue: int | float | complex | date | datetime | None = None, maxValue: int | float | complex | date | datetime | None = None, step: int | float | complex | timedelta | None = 1, - dataRange: DataRange | None = None, prefix: str | None = None, + dataRange: DataRange | None = None, + prefix: str | None = None, random: bool | None = None, distribution: DataDistribution | None = None, baseColumn: str | None = None, @@ -972,7 +983,7 @@ def withColumn( omit: bool = False, implicit: bool = False, noWarn: bool = False, - **kwargs + **kwargs, ) -> "DataGenerator": """ Adds a new column generation specification to the `DataGenerator`. @@ -1022,14 +1033,16 @@ def withColumn( # handle migration of old `min` and `max` options if _OLD_MIN_OPTION in kwargs: - assert minValue is None, \ - "Only one of `minValue` and `minValue` can be specified. Use of `minValue` is preferred" + assert ( + minValue is None + ), "Only one of `minValue` and `minValue` can be specified. Use of `minValue` is preferred" minValue = kwargs[_OLD_MIN_OPTION] kwargs.pop(_OLD_MIN_OPTION, None) if _OLD_MAX_OPTION in kwargs: - assert maxValue is None, \ - "Only one of `maxValue` and `maxValue` can be specified. Use of `maxValue` is preferred" + assert ( + maxValue is None + ), "Only one of `maxValue` and `maxValue` can be specified. Use of `maxValue` is preferred" maxValue = kwargs[_OLD_MAX_OPTION] kwargs.pop(_OLD_MAX_OPTION, None) @@ -1040,8 +1053,9 @@ def withColumn( new_props.update(kwargs) self.logger.info(f"effective range: {minValue}, {maxValue}, {step} args: {kwargs}") - self.logger.info("adding column - `%s` with baseColumn : `%s`, implicit : %s , omit %s", - colName, baseColumn, implicit, omit) + self.logger.info( + "adding column - `%s` with baseColumn : `%s`, implicit : %s , omit %s", colName, baseColumn, implicit, omit + ) newColumn = self._generateColumnDefinition( colName, colType, @@ -1055,7 +1069,7 @@ def withColumn( dataRange=dataRange, implicit=implicit, omit=omit, - **new_props + **new_props, ) # note for inferred columns, the column type is initially sey to a StringType but may be superceded later @@ -1076,9 +1090,7 @@ def _loadColumnsFromInitializationDicts(self, columns: list[dict[str, Any]]) -> if not isinstance(v, dict): continue value_superclass = ( - DataRange if k == "dataRange" - else DataDistribution if k == "distribution" - else TextGenerator + DataRange if k == "dataRange" else DataDistribution if k == "distribution" else TextGenerator ) value_subclasses = value_superclass.__subclasses__() if v["kind"] not in [s.__name__ for s in value_subclasses]: @@ -1108,8 +1120,9 @@ def _mkSqlStructFromList(self, fields: list[str | tuple[str, str]]) -> str: name of the field within the struct. The second element must be a SQL expression that will be used to generate the field value, and may reference previously defined columns. 
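[Editor's note: a hedged sketch, not part of this patch, of the pattern/type matching done by `withColumnSpecs` above, applied to an inferred schema. `df_source` and the option names (`percentNulls`, string type names for `matchTypes`) are assumptions drawn from the wider library.]

import dbldatagen as dg
from pyspark.sql.types import StringType

spec = (
    dg.DataGenerator(spark, name="from_schema", rows=10_000)
    .withSchema(df_source.schema)                      # df_source is assumed
    .withColumnSpecs(patterns=".*_id", matchTypes=["long"], minValue=1, maxValue=10_000)
    .withColumnSpecs(matchTypes=StringType(), percentNulls=0.05)
)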
""" - assert fields is not None and isinstance(fields, list), \ - "Fields must be a non-empty list of fields that make up the struct elements" + assert fields is not None and isinstance( + fields, list + ), "Fields must be a non-empty list of fields that make up the struct elements" assert len(fields) >= 1, "Fields must be a non-empty list of fields that make up the struct elements" struct_expressions = [] @@ -1129,8 +1142,9 @@ def _mkSqlStructFromList(self, fields: list[str | tuple[str, str]]) -> str: return struct_expression def _mkStructFromDict(self, fields: dict[str, Any]) -> str: - assert fields is not None and isinstance(fields, dict), \ - "Fields must be a non-empty dict of fields that make up the struct elements" + assert fields is not None and isinstance( + fields, dict + ), "Fields must be a non-empty dict of fields that make up the struct elements" struct_expressions = [] for key, value in fields.items(): @@ -1153,7 +1167,7 @@ def withStructColumn( colName: str, fields: list[str | tuple[str, str]] | dict[str, Any] | None = None, asJson: bool = False, - **kwargs + **kwargs, ) -> "DataGenerator": """ Adds a struct column to the synthetic data generation specification. This will add a new column composed of @@ -1179,13 +1193,13 @@ def withStructColumn( When using the `dict` form of the field specifications, a field whose value is a list will be treated as creating a SQL array literal. """ - assert fields is not None and isinstance(fields, list | dict), \ - "Fields argument must be a list of field specifications or dict outlining the target structure " + assert fields is not None and isinstance( + fields, list | dict + ), "Fields argument must be a list of field specifications or dict outlining the target structure " assert isinstance(colName, str) and len(colName) > 0, "Must specify a column name" if isinstance(fields, list): - assert len(fields) > 0, \ - "Must specify at least one field for struct column" + assert len(fields) > 0, "Must specify at least one field for struct column" struct_expr = self._mkSqlStructFromList(fields) elif isinstance(fields, dict): struct_expr = self._mkStructFromDict(fields) @@ -1209,9 +1223,9 @@ def _generateColumnDefinition( implicit: bool = False, omit: bool = False, nullable: bool = True, - **kwargs + **kwargs, ) -> ColumnGenerationSpec: - """ generate field definition and column spec + """generate field definition and column spec .. note:: Any time that a new column definition is added, we'll mark that the build plan needs to be regenerated. @@ -1267,7 +1281,7 @@ def _generateColumnDefinition( verbose=self.verbose, debug=self.debug, seedColumnName=self._seedColumnName, - **new_props + **new_props, ) self._columnSpecsByName[colName] = column_spec @@ -1285,9 +1299,9 @@ def _generateColumnDefinition( return column_spec def _getBaseDataFrame( - self, startId: int | None = None, streaming: bool = False, options: dict[str, Any] | None = None + self, startId: int | None = None, streaming: bool = False, options: dict[str, Any] | None = None ) -> DataFrame: - """ generate the base data frame and seed column (which defaults to `id`) , partitioning the data if necessary + """generate the base data frame and seed column (which defaults to `id`) , partitioning the data if necessary This is used when generating the test data. 
@@ -1310,11 +1324,7 @@ def _getBaseDataFrame( f"Generating data frame with ids from {startId} to {end_id} with {id_partitions} partitions" ) - df = self.sparkSession.range( - start=start_id, - end=end_id, - numPartitions=id_partitions - ) + df = self.sparkSession.range(start=start_id, end=end_id, numPartitions=id_partitions) # spark.range generates a dataframe with the column `id` so rename it if its not our seed column if self._seedColumnName != datagen_constants.SPARK_RANGE_COLUMN: @@ -1322,9 +1332,7 @@ def _getBaseDataFrame( return df - status = ( - f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions" - ) + status = f"Generating streaming data frame with ids from {startId} to {end_id} with {id_partitions} partitions" self.logger.info(status) self.executionHistory.append(status) @@ -1341,15 +1349,14 @@ def _getBaseDataFrame( else: return ( - reader - .option("rowsPerSecond", 1) + reader.option("rowsPerSecond", 1) .option("numPartitions", id_partitions) .load() .withColumnRenamed("value", self._seedColumnName) ) def _computeColumnBuildOrder(self) -> list[list[str]]: - """ compute the build ordering using a topological sort on dependencies + """compute the build ordering using a topological sort on dependencies In order to avoid references to columns that have not yet been generated, the test data generation process sorts the columns according to the order they need to be built. @@ -1362,8 +1369,8 @@ def _computeColumnBuildOrder(self) -> list[list[str]]: :returns: the build ordering """ dependency_ordering = [ - (x.name, set(x.dependencies)) if x.name != self._seedColumnName - else (self._seedColumnName, set()) for x in self._allColumnSpecs + (x.name, set(x.dependencies)) if x.name != self._seedColumnName else (self._seedColumnName, set()) + for x in self._allColumnSpecs ] self.logger.info("dependency list: %s", str(dependency_ordering)) @@ -1378,8 +1385,10 @@ def _computeColumnBuildOrder(self) -> list[list[str]]: return self._buildOrder - def _adjustBuildOrderForSqlDependencies(self, buildOrder: list[list[str]], columnSpecsByName: dict[str, ColumnGenerationSpec]) -> list[list[str]]: - """ Adjust column build order according to the following heuristics + def _adjustBuildOrderForSqlDependencies( + self, buildOrder: list[list[str]], columnSpecsByName: dict[str, ColumnGenerationSpec] + ) -> list[list[str]]: + """Adjust column build order according to the following heuristics 1: if the column being built in a specific build order phase has a SQL expression and it references other columns in the same build phase (or potentially references them as the expression parsing is @@ -1488,8 +1497,9 @@ def withConstraint(self, constraint: Constraint) -> "DataGenerator": constraint may also affect other aspects of the data generation. """ assert constraint is not None, "Constraint cannot be empty" - assert isinstance(constraint, Constraint), \ - "Value for 'constraint' must be an instance or subclass of the Constraint class." + assert isinstance( + constraint, Constraint + ), "Value for 'constraint' must be an instance or subclass of the Constraint class." 
self._constraints.append(constraint) return self @@ -1509,8 +1519,9 @@ def withConstraints(self, constraints: list[Constraint]) -> "DataGenerator": for constraint in constraints: assert constraint is not None, "Constraint cannot be empty" - assert isinstance(constraint, Constraint), \ - "Constraint must be an instance of, or an instance of a subclass of the Constraint class" + assert isinstance( + constraint, Constraint + ), "Constraint must be an instance of, or an instance of a subclass of the Constraint class" self._constraints.extend(constraints) return self @@ -1531,10 +1542,10 @@ def withSqlConstraint(self, sqlExpression: str) -> "DataGenerator": return self def _loadConstraintsFromInitializationDicts(self, constraints: list[dict[str, Any]]) -> "DataGenerator": - """ Adds a set of constraints to the synthetic generation specification. - :param constraints: A list of constraints as dictionaries - :returns: A modified in-place instance of a data generator allowing for chaining of calls - following a builder pattern + """Adds a set of constraints to the synthetic generation specification. + :param constraints: A list of constraints as dictionaries + :returns: A modified in-place instance of a data generator allowing for chaining of calls + following a builder pattern """ for c in constraints: t = next((s for s in Constraint.__subclasses__() if s.__name__ == c["kind"]), Constraint) @@ -1583,7 +1594,7 @@ def computeBuildPlan(self) -> "DataGenerator": return self def _applyPreGenerationConstraints(self, withStreaming: bool = False) -> None: - """ Apply pre data generation constraints """ + """Apply pre data generation constraints""" if self._constraints is not None and len(self._constraints) > 0: for constraint in self._constraints: assert isinstance(constraint, Constraint), "Value for 'constraint' should be of type 'Constraint'" @@ -1592,16 +1603,20 @@ def _applyPreGenerationConstraints(self, withStreaming: bool = False) -> None: constraint.prepareDataGenerator(self) def _applyPostGenerationConstraints(self, df: DataFrame) -> DataFrame: - """ Build and apply the constraints using two mechanisms - - Apply transformations to dataframe - - Apply expressions as SQL filters using where clauses""" + """Build and apply the constraints using two mechanisms + - Apply transformations to dataframe + - Apply expressions as SQL filters using where clauses""" if self._constraints is not None and len(self._constraints) > 0: for constraint in self._constraints: df = constraint.transformDataframe(self, df) # get set of constraint expressions - constraint_expressions = [constraint.filterExpression for constraint in self._constraints] + constraint_expressions = [ + constraint.filterExpression + for constraint in self._constraints + if constraint.filterExpression is not None + ] combined_constraint_expression = Constraint.mkCombinedConstraintExpression(constraint_expressions) # apply the filter @@ -1616,7 +1631,7 @@ def build( withTempView: bool = False, withView: bool = False, withStreaming: bool = False, - options: dict[str, Any] | None = None + options: dict[str, Any] | None = None, ) -> DataFrame: """ Builds a Spark `DataFrame` from the current `DataGenerator`. 
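[Editor's note: a minimal build sketch, not part of this patch. Any constraints registered via `withConstraint`/`withSqlConstraint` are applied during `build()`. The spec name, column, and the `rowsPerSecond` value are assumptions; the rate-source option itself appears in `_getBaseDataFrame` above.]

import dbldatagen as dg

spec = (
    dg.DataGenerator(spark, name="build_example", rows=1_000_000, partitions=8)
    .withIdOutput()
    .withColumn("amount", "double", minValue=1.0, maxValue=1000.0, random=True)
    .withSqlConstraint("amount > 0")
)
df = spec.build()                                                           # batch
stream_df = spec.build(withStreaming=True, options={"rowsPerSecond": 500})  # rate-source stream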
@@ -1643,11 +1658,13 @@ def build( self.computeBuildPlan() output_columns = self.getOutputColumnNames() - ensure(output_columns is not None and len(output_columns) > 0, - """ + ensure( + output_columns is not None and len(output_columns) > 0, + """ | You must specify at least one column for output | - use withIdOutput() to output base seed column - """) + """, + ) df1 = self._getBaseDataFrame(self.starting_id, streaming=withStreaming, options=options) @@ -1726,7 +1743,9 @@ def _sqlTypeFromSparkType(dt: DataType) -> str: return dt.simpleString() @staticmethod - def _mkInsertOrUpdateStatement(columns: list[str], srcAlias: str, substitutions: list[str] | None, isUpdate: bool = True) -> str: + def _mkInsertOrUpdateStatement( + columns: list[str], srcAlias: str, substitutions: list[str] | None, isUpdate: bool = True + ) -> str: if substitutions is None: substitutions = [] results = [] @@ -1746,11 +1765,7 @@ def _mkInsertOrUpdateStatement(columns: list[str], srcAlias: str, substitutions: return ", ".join(results) def scriptTable( - self, - name: str | None = None, - location: str | None = None, - tableFormat: str = "delta", - asHtml: bool = False + self, name: str | None = None, location: str | None = None, tableFormat: str = "delta", asHtml: bool = False ) -> str: """ Gets a Spark SQL `CREATE TABLE AS SELECT` statement suitable for the format of test data set. @@ -1769,11 +1784,13 @@ def scriptTable( output_columns = self.getOutputColumnNamesAndTypes() results = [f"CREATE TABLE IF NOT EXISTS {name} ("] - ensure(output_columns is not None and len(output_columns) > 0, - """ + ensure( + output_columns is not None and len(output_columns) > 0, + """ | You must specify at least one column for output | - use withIdOutput() to output base seed column - """) + """, + ) col_expressions = [] for col_to_output in output_columns: @@ -1809,7 +1826,7 @@ def scriptMerge( insertColumnExprs: list[str] | None = None, srcAlias: str = "src", tgtAlias: str = "tgt", - asHtml: bool = False + asHtml: bool = False, ) -> str: """ Gets a Spark SQL `MERGE` statement suitable for the format of test dataset. 
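[Editor's note: a hedged sketch, not part of this patch, emitting DDL for the generated schema with `scriptTable`, reusing the `spec` from the previous sketch. The table name and location are placeholders.]

ddl = spec.scriptTable(
    name="test_db.synthetic_orders",       # placeholder name
    location="/tmp/synthetic_orders",      # placeholder path
    tableFormat="delta",
)
print(ddl)   # returns a "CREATE TABLE IF NOT EXISTS test_db.synthetic_orders (..." statement string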
@@ -1850,11 +1867,13 @@ def scriptMerge( # get list of column names output_columns = [x[0] for x in self.getOutputColumnNamesAndTypes()] - ensure(output_columns is not None and len(output_columns) > 0, - """ + ensure( + output_columns is not None and len(output_columns) > 0, + """ | You must specify at least one column for output | - use withIdOutput() to output base seed column - """) + """, + ) # use list of column names if not supplied if insertColumns is None: @@ -1886,8 +1905,9 @@ def scriptMerge( update_clause = update_clause + " SET *" else: update_clause = ( - update_clause + " SET " + - DataGenerator._mkInsertOrUpdateStatement(updateColumns, srcAlias, updateColumnExprs) + update_clause + + " SET " + + DataGenerator._mkInsertOrUpdateStatement(updateColumns, srcAlias, updateColumnExprs) ) results.append(update_clause) @@ -1905,8 +1925,12 @@ def scriptMerge( ins_clause = ins_clause + " *" else: ins_clause = ( - ins_clause + "(" + ",".join(insertColumns) + ") VALUES (" + - DataGenerator._mkInsertOrUpdateStatement(insertColumns, srcAlias, insertColumnExprs, False) + ")" + ins_clause + + "(" + + ",".join(insertColumns) + + ") VALUES (" + + DataGenerator._mkInsertOrUpdateStatement(insertColumns, srcAlias, insertColumnExprs, False) + + ")" ) results.append(ins_clause) @@ -1918,10 +1942,10 @@ def scriptMerge( return result def saveAsDataset( - self, - dataset: OutputDataset, - with_streaming: bool | None = None, - generator_options: dict[str, Any] | None = None + self, + dataset: OutputDataset, + with_streaming: bool | None = None, + generator_options: dict[str, Any] | None = None, ) -> StreamingQuery | None: """ Builds a `DataFrame` from the `DataGenerator` and writes the data to an output dataset (e.g. a table or files). diff --git a/dbldatagen/datagen_types.py b/dbldatagen/datagen_types.py new file mode 100644 index 00000000..60228bc4 --- /dev/null +++ b/dbldatagen/datagen_types.py @@ -0,0 +1,13 @@ +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module defines type aliases for common types used throughout the library. +""" +import numpy as np +from pyspark.sql import Column + + +NumericLike = float | int | np.float64 | np.int32 | np.int64 +ColumnLike = Column | str diff --git a/dbldatagen/datarange.py b/dbldatagen/datarange.py index 1e7c1273..a7ec545b 100644 --- a/dbldatagen/datarange.py +++ b/dbldatagen/datarange.py @@ -3,49 +3,54 @@ # """ -This module defines the DataRange abstract class - -its not used directly, but used as base type for explicit DateRange and NRange types to ensure correct tracking of -changes to method names when refactoring - +This module defines the DataRange abstract class. """ -from .serialization import SerializableToDict +from pyspark.sql.types import DataType + +from dbldatagen.serialization import SerializableToDict class DataRange(SerializableToDict): - """ Abstract class used as base class for NRange and DateRange """ + """Abstract class used as base class for NRange and DateRange""" + + minValue: object | None + maxValue: object | None - def isEmpty(self): - """Check if object is empty (i.e all instance vars of note are `None`)""" - raise NotImplementedError("method not implemented") + def isEmpty(self) -> bool: + """Checks if object is empty (i.e all instance vars of note are `None`). 
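[Editor's note: a hedged sketch, not part of this patch, of how the new `datagen_types` aliases might be used to annotate a helper that accepts either a Spark Column or a column name, and either Python or numpy numbers. The helper `scaled` is hypothetical.]

import pyspark.sql.functions as F
from pyspark.sql import Column

from dbldatagen.datagen_types import ColumnLike, NumericLike


def scaled(col: ColumnLike, factor: NumericLike) -> Column:
    """Return the column (given by name or Column) multiplied by `factor`."""
    base = F.col(col) if isinstance(col, str) else col
    return base * float(factor)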
- def isFullyPopulated(self): - """Check is all instance vars are populated""" - raise NotImplementedError("method not implemented") + :return: True if the object is empty + """ + raise NotImplementedError(f"'{self.__class__.__name__}' does not implement method 'isEmpty'") - def adjustForColumnDatatype(self, ctype): - """ Adjust default values for column output type""" - raise NotImplementedError("method not implemented") + def isFullyPopulated(self) -> bool: + """Checks if all instance vars are populated. - def getDiscreteRange(self): - """Convert range to discrete range""" - raise NotImplementedError("method not implemented") + :return: True if all instance vars are populated, False otherwise + """ + raise NotImplementedError(f"'{self.__class__.__name__}' does not implement method 'isFullyPopulated'") - def getContinuousRange(self): - """Convert range to continuous range""" - raise NotImplementedError("method not implemented") + def adjustForColumnDatatype(self, ctype: DataType) -> None: + """Adjust default values for column output type. - def getScale(self): - """Get scale of range""" - raise NotImplementedError("method not implemented") + :param ctype: Spark SQL data type for column + """ + raise NotImplementedError(f"'{self.__class__.__name__}' does not implement method 'adjustForColumnDatatype'") + + def getDiscreteRange(self) -> float: + """Convert range to discrete range. + + :return: Discrete range object + """ + raise NotImplementedError(f"'{self.__class__.__name__}' does not implement method 'getDiscreteRange'") @property - def min(self): + def min(self) -> object: """get the `min` attribute""" return self.minValue - + @property - def max(self): + def max(self) -> object: """get the `max` attribute""" return self.maxValue diff --git a/dbldatagen/datasets/__init__.py b/dbldatagen/datasets/__init__.py index c9339c36..4f4f017c 100644 --- a/dbldatagen/datasets/__init__.py +++ b/dbldatagen/datasets/__init__.py @@ -10,23 +10,23 @@ __all__ = [ - "BasicGeometriesProvider", - "BasicProcessHistorianProvider", - "BasicStockTickerProvider", - "BasicTelematicsProvider", - "BasicUserProvider", - "BenchmarkGroupByProvider", - "DatasetProvider", - "MultiTableSalesOrderProvider", - "MultiTableTelephonyProvider", - "basic_geometries", - "basic_process_historian", - "basic_stock_ticker", - "basic_telematics", - "basic_user", - "benchmark_groupby", - "dataset_definition", - "dataset_provider", - "multi_table_sales_order_provider", - "multi_table_telephony_provider" - ] + "BasicGeometriesProvider", + "BasicProcessHistorianProvider", + "BasicStockTickerProvider", + "BasicTelematicsProvider", + "BasicUserProvider", + "BenchmarkGroupByProvider", + "DatasetProvider", + "MultiTableSalesOrderProvider", + "MultiTableTelephonyProvider", + "basic_geometries", + "basic_process_historian", + "basic_stock_ticker", + "basic_telematics", + "basic_user", + "benchmark_groupby", + "dataset_definition", + "dataset_provider", + "multi_table_sales_order_provider", + "multi_table_telephony_provider", +] diff --git a/dbldatagen/datasets/basic_geometries.py b/dbldatagen/datasets/basic_geometries.py index b24e43b9..83d322c2 100644 --- a/dbldatagen/datasets/basic_geometries.py +++ b/dbldatagen/datasets/basic_geometries.py @@ -8,10 +8,7 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="basic/geometries", - summary="Geometry WKT dataset", - autoRegister=True, - supportsStreaming=True) +@dataset_definition(name="basic/geometries", summary="Geometry WKT dataset", 
autoRegister=True, supportsStreaming=True) class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Basic Geometry WKT Dataset @@ -34,6 +31,7 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset streaming dataframe, and so the flag `supportsStreaming` is set to True. """ + MIN_LOCATION_ID = 1000000 MAX_LOCATION_ID = 9223372036854775807 DEFAULT_MIN_LAT = -90.0 @@ -41,18 +39,26 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset DEFAULT_MIN_LON = -180.0 DEFAULT_MAX_LON = 180.0 COLUMN_COUNT = 2 - ALLOWED_OPTIONS: ClassVar[list[str]] = [ + ALLOWED_OPTIONS: ClassVar[list[str]] = [ "geometryType", "maxVertices", "minLatitude", "maxLatitude", "minLongitude", "maxLongitude", - "random" + "random", ] @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) geometryType = options.get("geometryType", "point") maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3) @@ -67,20 +73,42 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N if partitions is None or partitions < 0: partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT) - df_spec = ( - dg.DataGenerator(sparkSession=sparkSession, name="test_data_set1", rows=rows, - partitions=partitions, randomSeedMethod="hash_fieldname") - .withColumn("location_id", "long", minValue=self.MIN_LOCATION_ID, maxValue=self.MAX_LOCATION_ID, - uniqueValues=rows, random=generateRandom) + df_spec = dg.DataGenerator( + sparkSession=sparkSession, + name="test_data_set1", + rows=rows, + partitions=partitions, + randomSeedMethod="hash_fieldname", + ).withColumn( + "location_id", + "long", + minValue=self.MIN_LOCATION_ID, + maxValue=self.MAX_LOCATION_ID, + uniqueValues=rows, + random=generateRandom, ) if geometryType == "point": if maxVertices > 1: w.warn("Ignoring property maxVertices for point geometries", stacklevel=2) df_spec = ( - df_spec.withColumn("lat", "float", minValue=minLatitude, maxValue=maxLatitude, - step=1e-5, random=generateRandom, omit=True) - .withColumn("lon", "float", minValue=minLongitude, maxValue=maxLongitude, - step=1e-5, random=generateRandom, omit=True) + df_spec.withColumn( + "lat", + "float", + minValue=minLatitude, + maxValue=maxLatitude, + step=1e-5, + random=generateRandom, + omit=True, + ) + .withColumn( + "lon", + "float", + minValue=minLongitude, + maxValue=maxLongitude, + step=1e-5, + random=generateRandom, + omit=True, + ) .withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')") ) elif geometryType == "lineString": @@ -89,37 +117,55 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N w.warn("Parameter maxVertices must be >=2 for 'lineString' geometries; Setting to 2", stacklevel=2) j = 0 while j < maxVertices: - df_spec = ( - df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude, - step=1e-5, random=generateRandom, omit=True) - .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude, - step=1e-5, random=generateRandom, omit=True) + df_spec = df_spec.withColumn( + f"lat_{j}", + 
"float", + minValue=minLatitude, + maxValue=maxLatitude, + step=1e-5, + random=generateRandom, + omit=True, + ).withColumn( + f"lon_{j}", + "float", + minValue=minLongitude, + maxValue=maxLongitude, + step=1e-5, + random=generateRandom, + omit=True, ) j = j + 1 concatCoordinatesExpr = [f"concat(lon_{j}, ' ', lat_{j}, ', ')" for j in range(maxVertices)] concatPairsExpr = f"replace(concat('LINESTRING(', {', '.join(concatCoordinatesExpr)}, ')'), ', )', ')')" - df_spec = ( - df_spec.withColumn("wkt", "string", expr=concatPairsExpr) - ) + df_spec = df_spec.withColumn("wkt", "string", expr=concatPairsExpr) elif geometryType == "polygon": if maxVertices < 3: maxVertices = 3 w.warn("Parameter maxVertices must be >=3 for 'polygon' geometries; Setting to 3", stacklevel=2) j = 0 while j < maxVertices: - df_spec = ( - df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude, - step=1e-5, random=generateRandom, omit=True) - .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude, - step=1e-5, random=generateRandom, omit=True) + df_spec = df_spec.withColumn( + f"lat_{j}", + "float", + minValue=minLatitude, + maxValue=maxLatitude, + step=1e-5, + random=generateRandom, + omit=True, + ).withColumn( + f"lon_{j}", + "float", + minValue=minLongitude, + maxValue=maxLongitude, + step=1e-5, + random=generateRandom, + omit=True, ) j = j + 1 vertexIndices = [*list(range(maxVertices)), 0] concatCoordinatesExpr = [f"concat(lon_{j}, ' ', lat_{j}, ', ')" for j in vertexIndices] concatPairsExpr = f"replace(concat('POLYGON(', {', '.join(concatCoordinatesExpr)}, ')'), ', )', ')')" - df_spec = ( - df_spec.withColumn("wkt", "string", expr=concatPairsExpr) - ) + df_spec = df_spec.withColumn("wkt", "string", expr=concatPairsExpr) else: raise ValueError("geometryType must be 'point', 'lineString', or 'polygon'") diff --git a/dbldatagen/datasets/basic_process_historian.py b/dbldatagen/datasets/basic_process_historian.py index ba7cee13..fc2a8d8e 100644 --- a/dbldatagen/datasets/basic_process_historian.py +++ b/dbldatagen/datasets/basic_process_historian.py @@ -8,10 +8,12 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="basic/process_historian", - summary="Basic Historian Data for Process Manufacturing", - autoRegister=True, - supportsStreaming=True) +@dataset_definition( + name="basic/process_historian", + summary="Basic Historian Data for Process Manufacturing", + autoRegister=True, + supportsStreaming=True, +) class BasicProcessHistorianProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Basic Process Historian Dataset @@ -40,6 +42,7 @@ class BasicProcessHistorianProvider(DatasetProvider.NoAssociatedDatasetsMixin, D streaming dataframe, and so the flag `supportsStreaming` is set to True. 
""" + MIN_DEVICE_ID = 0x100000000 MAX_DEVICE_ID = 9223372036854775807 MIN_PROPERTY_VALUE = 50.0 @@ -57,12 +60,19 @@ class BasicProcessHistorianProvider(DatasetProvider.NoAssociatedDatasetsMixin, D "startTimestamp", "endTimestamp", "dataQualityRatios", - "random" + "random", ] @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: - + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES) @@ -81,36 +91,51 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N tag_names = [f"HEX-{str(j).zfill(int(np.ceil(np.log10(numTags))))}_INLET_TMP" for j in range(numTags)] plant_ids = [f"PLANT-{str(j).zfill(int(np.ceil(np.log10(numPlants))))}" for j in range(numPlants)] testDataSpec = ( - dg.DataGenerator(sparkSession, name="process_historian_data", rows=rows, - partitions=partitions, - randomSeedMethod="hash_fieldname") - .withColumn("internal_device_id", "long", minValue=self.MIN_DEVICE_ID, maxValue=self.MAX_DEVICE_ID, - uniqueValues=numDevices, omit=True, baseColumnType="hash") + dg.DataGenerator( + sparkSession, + name="process_historian_data", + rows=rows, + partitions=partitions, + randomSeedMethod="hash_fieldname", + ) + .withColumn( + "internal_device_id", + "long", + minValue=self.MIN_DEVICE_ID, + maxValue=self.MAX_DEVICE_ID, + uniqueValues=numDevices, + omit=True, + baseColumnType="hash", + ) .withColumn("device_id", "string", format="0x%09x", baseColumn="internal_device_id") .withColumn("plant_id", "string", values=plant_ids, baseColumn="internal_device_id") .withColumn("tag_name", "string", values=tag_names, baseColumn="internal_device_id") - .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp, - interval="1 second", random=generateRandom) - .withColumn("value", "float", minValue=self.MIN_PROPERTY_VALUE, maxValue=self.MAX_PROPERTY_VALUE, - step=1e-3, random=generateRandom) + .withColumn( + "ts", "timestamp", begin=startTimestamp, end=endTimestamp, interval="1 second", random=generateRandom + ) + .withColumn( + "value", + "float", + minValue=self.MIN_PROPERTY_VALUE, + maxValue=self.MAX_PROPERTY_VALUE, + step=1e-3, + random=generateRandom, + ) .withColumn("engineering_units", "string", expr="'Deg.F'") ) # Add the data quality columns if they were provided if dataQualityRatios is not None: if "pctQuestionable" in dataQualityRatios: testDataSpec = testDataSpec.withColumn( - "is_questionable", "boolean", - expr=f"rand() < {dataQualityRatios['pctQuestionable']}" + "is_questionable", "boolean", expr=f"rand() < {dataQualityRatios['pctQuestionable']}" ) if "pctSubstituted" in dataQualityRatios: testDataSpec = testDataSpec.withColumn( - "is_substituted", "boolean", - expr=f"rand() < {dataQualityRatios['pctSubstituted']}" + "is_substituted", "boolean", expr=f"rand() < {dataQualityRatios['pctSubstituted']}" ) if "pctAnnotated" in dataQualityRatios: testDataSpec = testDataSpec.withColumn( - "is_annotated", "boolean", - expr=f"rand() < {dataQualityRatios['pctAnnotated']}" + "is_annotated", "boolean", expr=f"rand() < {dataQualityRatios['pctAnnotated']}" ) return testDataSpec diff --git a/dbldatagen/datasets/basic_stock_ticker.py 
b/dbldatagen/datasets/basic_stock_ticker.py index 74cfe3f9..7f360d28 100644 --- a/dbldatagen/datasets/basic_stock_ticker.py +++ b/dbldatagen/datasets/basic_stock_ticker.py @@ -8,10 +8,9 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="basic/stock_ticker", - summary="Stock ticker dataset", - autoRegister=True, - supportsStreaming=True) +@dataset_definition( + name="basic/stock_ticker", summary="Stock ticker dataset", autoRegister=True, supportsStreaming=True +) class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Basic Stock Ticker Dataset @@ -32,16 +31,22 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase Note that this dataset does not use any features that would prevent it from being used as a source for a streaming dataframe, and so the flag `supportsStreaming` is set to True. """ + DEFAULT_NUM_SYMBOLS = 100 DEFAULT_START_DATE = "2024-10-01" COLUMN_COUNT = 8 - ALLOWED_OPTIONS: ClassVar[list[str]] = [ - "numSymbols", - "startDate" - ] + ALLOWED_OPTIONS: ClassVar[list[str]] = ["numSymbols", "startDate"] @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: object, + ) -> DataGenerator: numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS) startDate = options.get("startDate", self.DEFAULT_START_DATE) @@ -55,44 +60,73 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N raise ValueError("'numSymbols' must be > 0") df_spec = ( - dg.DataGenerator(sparkSession=sparkSession, rows=rows, - partitions=partitions, randomSeedMethod="hash_fieldname") + dg.DataGenerator( + sparkSession=sparkSession, rows=rows, partitions=partitions, randomSeedMethod="hash_fieldname" + ) .withColumn("symbol_id", "long", minValue=676, maxValue=676 + numSymbols - 1) - .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1, - baseColumn="symbol_id", omit=True) - .withColumn("symbol", "string", - expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''), - x -> case when ascii(x) < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""") + .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1, baseColumn="symbol_id", omit=True) + .withColumn( + "symbol", + "string", + expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''), + x -> case when ascii(x) < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""", + ) .withColumn("days_from_start_date", "int", expr=f"floor(try_divide(id, {numSymbols}))", omit=True) .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)") - .withColumn("start_value", "decimal(11,2)", - values=[1.0 + 199.0 * random() for _ in range(max(1, int(numSymbols / 10)))], omit=True) - .withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(max(1, int(numSymbols / 10)))], - baseColumn="symbol_id") - .withColumn("volatility", "float", values=[0.0075 * random() for _ in range(max(1, int(numSymbols / 10)))], - baseColumn="symbol_id", omit=True) - .withColumn("prev_modifier_sign", "float", - expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end""", - 
omit=True) - .withColumn("modifier_sign", "float", - expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end", - omit=True) - .withColumn("open_base", "decimal(11,2)", - expr=f"""start_value + .withColumn( + "start_value", + "decimal(11,2)", + values=[1.0 + 199.0 * random() for _ in range(max(1, int(numSymbols / 10)))], + omit=True, + ) + .withColumn( + "growth_rate", + "float", + values=[-0.1 + 0.35 * random() for _ in range(max(1, int(numSymbols / 10)))], + baseColumn="symbol_id", + ) + .withColumn( + "volatility", + "float", + values=[0.0075 * random() for _ in range(max(1, int(numSymbols / 10)))], + baseColumn="symbol_id", + omit=True, + ) + .withColumn( + "prev_modifier_sign", + "float", + expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end" "", + omit=True, + ) + .withColumn("modifier_sign", "float", expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end", omit=True) + .withColumn( + "open_base", + "decimal(11,2)", + expr=f"""start_value + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17)) + (growth_rate * start_value * try_divide(days_from_start_date - 1, 365))""", - omit=True) - .withColumn("close_base", "decimal(11,2)", - expr="""start_value + omit=True, + ) + .withColumn( + "close_base", + "decimal(11,2)", + expr="""start_value + (volatility * start_value * sin(id % 17)) + (growth_rate * start_value * try_divide(days_from_start_date, 365))""", - omit=True) - .withColumn("high_base", "decimal(11,2)", - expr="greatest(open_base, close_base) + rand() * volatility * open_base", - omit=True) - .withColumn("low_base", "decimal(11,2)", - expr="least(open_base, close_base) - rand() * volatility * open_base", - omit=True) + omit=True, + ) + .withColumn( + "high_base", + "decimal(11,2)", + expr="greatest(open_base, close_base) + rand() * volatility * open_base", + omit=True, + ) + .withColumn( + "low_base", + "decimal(11,2)", + expr="least(open_base, close_base) - rand() * volatility * open_base", + omit=True, + ) .withColumn("open", "decimal(11,2)", expr="greatest(open_base, 0.0)") .withColumn("close", "decimal(11,2)", expr="greatest(close_base, 0.0)") .withColumn("high", "decimal(11,2)", expr="greatest(high_base, 0.0)") diff --git a/dbldatagen/datasets/basic_telematics.py b/dbldatagen/datasets/basic_telematics.py index 27260219..5ac5662e 100644 --- a/dbldatagen/datasets/basic_telematics.py +++ b/dbldatagen/datasets/basic_telematics.py @@ -8,10 +8,9 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="basic/telematics", - summary="Telematics dataset for GPS tracking", - autoRegister=True, - supportsStreaming=True) +@dataset_definition( + name="basic/telematics", summary="Telematics dataset for GPS tracking", autoRegister=True, supportsStreaming=True +) class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Basic Telematics Dataset @@ -39,6 +38,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset streaming dataframe, and so the flag `supportsStreaming` is set to True. 
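Calling the reformatted telematics provider directly might look like this; the row count, device count, and bounding-box values are illustrative::

    from dbldatagen.datasets import BasicTelematicsProvider

    # Assumes an active SparkSession named `spark`.
    telematics_spec = BasicTelematicsProvider().getTableGenerator(
        spark,
        rows=100_000,
        partitions=8,
        numDevices=500,
        minLat=45.0,
        maxLat=47.0,
        minLon=-122.5,
        maxLon=-120.5,
        generateWkt=True,
    )
    telematics_df = telematics_spec.build()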
""" + MIN_DEVICE_ID = 1000000 MAX_DEVICE_ID = 9223372036854775807 DEFAULT_NUM_DEVICES = 1000 @@ -58,11 +58,19 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset "minLon", "maxLon", "generateWkt", - "random" + "random", ] @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES) @@ -110,24 +118,42 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N (minLat, maxLat) = (maxLat, minLat) w.warn("Received minLat > maxLat; Swapping values", stacklevel=2) df_spec = ( - dg.DataGenerator(sparkSession=sparkSession, rows=rows, - partitions=partitions, randomSeedMethod="hash_fieldname") - .withColumn("device_id", "long", minValue=self.MIN_DEVICE_ID, maxValue=self.MAX_DEVICE_ID, - uniqueValues=numDevices, random=generateRandom) - .withColumn("ts", "timestamp", begin=startTimestamp, end=endTimestamp, - interval="1 second", random=generateRandom) - .withColumn("base_lat", "float", minValue=minLat, maxValue=maxLat, step=0.5, - baseColumn="device_id", omit=True) - .withColumn("base_lon", "float", minValue=minLon, maxValue=maxLon, step=0.5, - baseColumn="device_id", omit=True) + dg.DataGenerator( + sparkSession=sparkSession, rows=rows, partitions=partitions, randomSeedMethod="hash_fieldname" + ) + .withColumn( + "device_id", + "long", + minValue=self.MIN_DEVICE_ID, + maxValue=self.MAX_DEVICE_ID, + uniqueValues=numDevices, + random=generateRandom, + ) + .withColumn( + "ts", "timestamp", begin=startTimestamp, end=endTimestamp, interval="1 second", random=generateRandom + ) + .withColumn( + "base_lat", "float", minValue=minLat, maxValue=maxLat, step=0.5, baseColumn="device_id", omit=True + ) + .withColumn( + "base_lon", "float", minValue=minLon, maxValue=maxLon, step=0.5, baseColumn="device_id", omit=True + ) .withColumn("unv_lat", "float", expr="base_lat + (0.5-format_number(rand(), 3))*1e-3", omit=True) .withColumn("unv_lon", "float", expr="base_lon + (0.5-format_number(rand(), 3))*1e-3", omit=True) - .withColumn("lat", "float", expr=f"""CASE WHEN unv_lat > {maxLat} THEN {maxLat} + .withColumn( + "lat", + "float", + expr=f"""CASE WHEN unv_lat > {maxLat} THEN {maxLat} ELSE CASE WHEN unv_lat < {minLat} THEN {minLat} - ELSE unv_lat END END""") - .withColumn("lon", "float", expr=f"""CASE WHEN unv_lon > {maxLon} THEN {maxLon} + ELSE unv_lat END END""", + ) + .withColumn( + "lon", + "float", + expr=f"""CASE WHEN unv_lon > {maxLon} THEN {maxLon} ELSE CASE WHEN unv_lon < {minLon} THEN {minLon} - ELSE unv_lon END END""") + ELSE unv_lon END END""", + ) .withColumn("heading", "integer", minValue=0, maxValue=359, step=1, random=generateRandom) .withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')", omit=not generateWkt) ) diff --git a/dbldatagen/datasets/basic_user.py b/dbldatagen/datasets/basic_user.py index c522b865..517bb0dc 100644 --- a/dbldatagen/datasets/basic_user.py +++ b/dbldatagen/datasets/basic_user.py @@ -29,11 +29,20 @@ class BasicUserProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvid streaming dataframe, and so the flag `supportsStreaming` is set to 
True. """ + MAX_LONG = 9223372036854775807 COLUMN_COUNT = 5 @DatasetProvider.allowed_options(options=["random", "dummyValues"]) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) dummyValues = options.get("dummyValues", 0) @@ -45,23 +54,21 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name" df_spec = ( - dg.DataGenerator(sparkSession=sparkSession, rows=rows, - partitions=partitions, - randomSeedMethod="hash_fieldname") + dg.DataGenerator( + sparkSession=sparkSession, rows=rows, partitions=partitions, randomSeedMethod="hash_fieldname" + ) .withColumn("customer_id", "long", minValue=1000000, maxValue=self.MAX_LONG, random=generateRandom) - .withColumn("name", "string", - template=r"\w \w|\w \w \w", random=generateRandom) - .withColumn("email", "string", - template=r"\w.\w@\w.com|\w@\w.co.u\k", random=generateRandom) - .withColumn("ip_addr", "string", - template=r"\n.\n.\n.\n", random=generateRandom) - .withColumn("phone", "string", - template=r"(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd", - random=generateRandom) + .withColumn("name", "string", template=r"\w \w|\w \w \w", random=generateRandom) + .withColumn("email", "string", template=r"\w.\w@\w.com|\w@\w.co.u\k", random=generateRandom) + .withColumn("ip_addr", "string", template=r"\n.\n.\n.\n", random=generateRandom) + .withColumn( + "phone", "string", template=r"(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd", random=generateRandom + ) ) if dummyValues > 0: - df_spec = df_spec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + df_spec = df_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return df_spec diff --git a/dbldatagen/datasets/benchmark_groupby.py b/dbldatagen/datasets/benchmark_groupby.py index 24bd8b35..9c251dae 100644 --- a/dbldatagen/datasets/benchmark_groupby.py +++ b/dbldatagen/datasets/benchmark_groupby.py @@ -8,10 +8,12 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="benchmark/groupby", - summary="Benchmarking dataset for GROUP BY queries in various database systems", - autoRegister=True, - supportsStreaming=True) +@dataset_definition( + name="benchmark/groupby", + summary="Benchmarking dataset for GROUP BY queries in various database systems", + autoRegister=True, + supportsStreaming=True, +) class BenchmarkGroupByProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Grouping Benchmark Dataset @@ -34,6 +36,7 @@ class BenchmarkGroupByProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase streaming dataframe, and so the flag `supportsStreaming` is set to True. 
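A sketch of generating the grouping benchmark table from the reformatted provider with explicit options (values are illustrative)::

    from dbldatagen.datasets import BenchmarkGroupByProvider

    # Assumes an active SparkSession named `spark`.
    groupby_spec = BenchmarkGroupByProvider().getTableGenerator(
        spark,
        rows=1_000_000,
        partitions=16,
        groups=100,
        percentNulls=0.05,
        random=True,
    )
    benchmark_df = groupby_spec.build()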
""" + MAX_LONG = 9223372036854775807 DEFAULT_NUM_GROUPS = 100 DEFAULT_PCT_NULLS = 0.0 @@ -41,7 +44,15 @@ class BenchmarkGroupByProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase ALLOWED_OPTIONS: ClassVar[list[str]] = ["groups", "percentNulls", "rows", "partitions", "tableName", "random"] @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) groups = options.get("groups", self.DEFAULT_NUM_GROUPS) @@ -71,26 +82,48 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name" df_spec = ( - dg.DataGenerator(sparkSession=sparkSession, rows=rows, - partitions=partitions, - randomSeedMethod="hash_fieldname") - .withColumn("base1", "integer", minValue=1, maxValue=groups, - uniqueValues=groups, random=generateRandom, omit=True) - .withColumn("base2", "integer", minValue=1, maxValue=groups, - uniqueValues=groups, random=generateRandom, omit=True) - .withColumn("base3", "integer", minValue=1, maxValue=(1 + int(rows / groups)), - uniqueValues=(1 + int(rows / groups)), random=generateRandom, omit=True) + dg.DataGenerator( + sparkSession=sparkSession, rows=rows, partitions=partitions, randomSeedMethod="hash_fieldname" + ) + .withColumn( + "base1", "integer", minValue=1, maxValue=groups, uniqueValues=groups, random=generateRandom, omit=True + ) + .withColumn( + "base2", "integer", minValue=1, maxValue=groups, uniqueValues=groups, random=generateRandom, omit=True + ) + .withColumn( + "base3", + "integer", + minValue=1, + maxValue=(1 + int(rows / groups)), + uniqueValues=(1 + int(rows / groups)), + random=generateRandom, + omit=True, + ) .withColumn("id1", "string", baseColumn="base1", format="id%03d", percentNulls=percentNulls) .withColumn("id2", "string", baseColumn="base2", format="id%03d", percentNulls=percentNulls) .withColumn("id3", "string", baseColumn="base3", format="id%010d", percentNulls=percentNulls) .withColumn("id4", "integer", minValue=1, maxValue=groups, random=generateRandom, percentNulls=percentNulls) .withColumn("id5", "integer", minValue=1, maxValue=groups, random=generateRandom, percentNulls=percentNulls) - .withColumn("id6", "integer", minValue=1, maxValue=(1 + int(rows / groups)), random=generateRandom, - percentNulls=percentNulls) + .withColumn( + "id6", + "integer", + minValue=1, + maxValue=(1 + int(rows / groups)), + random=generateRandom, + percentNulls=percentNulls, + ) .withColumn("v1", "integer", minValue=1, maxValue=5, random=generateRandom, percentNulls=percentNulls) .withColumn("v2", "integer", minValue=1, maxValue=15, random=generateRandom, percentNulls=percentNulls) - .withColumn("v3", "decimal(9,6)", minValue=0.0, maxValue=100.0, - step=1e-6, random=generateRandom, percentNulls=percentNulls) + .withColumn( + "v3", + "decimal(9,6)", + minValue=0.0, + maxValue=100.0, + step=1e-6, + random=generateRandom, + percentNulls=percentNulls, + ) ) return df_spec diff --git a/dbldatagen/datasets/dataset_provider.py b/dbldatagen/datasets/dataset_provider.py index d67c440d..b3e8dfaa 100644 --- a/dbldatagen/datasets/dataset_provider.py +++ 
b/dbldatagen/datasets/dataset_provider.py @@ -58,6 +58,7 @@ class DatasetProvider(ABC): By default, all DatasetProvider classes should support batch usage. If a dataset provider supports streaming usage, the flag `supportsStreaming` should be set to `True` in the decorator. """ + DEFAULT_TABLE_NAME = "main" DEFAULT_ROWS = 100_000 DEFAULT_PARTITIONS = 4 @@ -75,7 +76,7 @@ class DatasetProvider(ABC): @dataclass class DatasetDefinition: - """ Dataset Definition class - stores the attributes related to the dataset for use by the implementation + """Dataset Definition class - stores the attributes related to the dataset for use by the implementation of the decorator. This stores the name of the dataset (e.g. `basic/user`), the list of tables provided by the dataset, @@ -85,6 +86,7 @@ class DatasetDefinition: It also allows specification of supporting tables which are tables computed from existing dataframes that can be provided by the dataset provider """ + name: str tables: list[str] primaryTable: str @@ -102,18 +104,20 @@ def isValidDataProviderType(cls, candidateDataProvider: type) -> bool: :return: True if valid DatasetProvider type, False otherwise """ - return (candidateDataProvider is not None and - isinstance(candidateDataProvider, type) and - issubclass(candidateDataProvider, cls)) + return ( + candidateDataProvider is not None + and isinstance(candidateDataProvider, type) + and issubclass(candidateDataProvider, cls) + ) @classmethod def getDatasetDefinition(cls) -> DatasetDefinition: - """ Get the dataset definition for the class """ + """Get the dataset definition for the class""" return cls._DATASET_DEFINITION @classmethod def getDatasetTables(cls) -> list[str]: - """ Get the dataset tables list for the class """ + """Get the dataset tables list for the class""" datasetDefinition = cls.getDatasetDefinition() if datasetDefinition is None or datasetDefinition.tables is None: @@ -123,7 +127,7 @@ def getDatasetTables(cls) -> list[str]: @classmethod def registerDataset(cls, datasetProvider: type) -> None: - """ Register the dataset provider type using metadata defined in the dataset provider + """Register the dataset provider type using metadata defined in the dataset provider :param datasetProvider: Dataset provider class :return: None @@ -143,14 +147,13 @@ def registerDataset(cls, datasetProvider: type) -> None: datasetDefinition = datasetProvider.getDatasetDefinition() - assert isinstance(datasetDefinition, cls.DatasetDefinition), \ - "retrieved datasetDefinition must be an instance of DatasetDefinition" + assert isinstance( + datasetDefinition, cls.DatasetDefinition + ), "retrieved datasetDefinition must be an instance of DatasetDefinition" - assert datasetDefinition.name is not None, \ - "datasetDefinition must contain a name for the data set" + assert datasetDefinition.name is not None, "datasetDefinition must contain a name for the data set" - assert issubclass(datasetDefinition.providerClass, cls), \ - "datasetClass must be a subclass of DatasetProvider" + assert issubclass(datasetDefinition.providerClass, cls), "datasetClass must be a subclass of DatasetProvider" if datasetDefinition.name in cls._registeredDatasetsMetadata: raise ValueError(f"Dataset provider is already registered for name `{datasetDefinition.name}`") @@ -160,7 +163,7 @@ def registerDataset(cls, datasetProvider: type) -> None: @classmethod def unregisterDataset(cls, name: str) -> None: - """ Unregister the dataset with the specified name + """Unregister the dataset with the specified name :param name: Name of the 
dataset to unregister """ @@ -180,7 +183,7 @@ def getRegisteredDatasets(cls) -> dict[str, DatasetDefinition]: return cls._registeredDatasetsMetadata @classmethod - def getRegisteredDatasetsVersion(cls) -> int : + def getRegisteredDatasetsVersion(cls) -> int: """ Get the registered datasets version indicator :return: A dictionary of registered datasets @@ -188,7 +191,15 @@ def getRegisteredDatasetsVersion(cls) -> int : return cls._registeredDatasetsVersion @abstractmethod - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: """Gets data generation instance that will produce table for named table :param sparkSession: Spark session to use @@ -207,8 +218,15 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N raise NotImplementedError("Base data provider does not provide any table generation specifications!") @abstractmethod - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str | None=None, rows: int=-1, partitions: int=-1, - **options: dict[str, Any]) -> DataGenerator: + def getAssociatedDataset( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: """ Gets associated datasets that are used in conjunction with the provider datasets. These may be associated lookup tables, tables that execute benchmarks or exercise key features as part of @@ -230,20 +248,23 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str | N raise NotImplementedError("Base data provider does not produce any supporting tables!") @staticmethod - def allowed_options(options: list[str]|None =None) -> Callable[[Callable], Callable]: - """ Decorator to enforce allowed options + def allowed_options(options: list[str] | None = None) -> Callable[[Callable], Callable]: + """Decorator to enforce allowed options - Used to document and enforce what options are allowed for each dataset provider implementation - If the signature of the getTableGenerator method changes, change the DEFAULT_OPTIONS constant - to include options that are always allowed + Used to document and enforce what options are allowed for each dataset provider implementation + If the signature of the getTableGenerator method changes, change the DEFAULT_OPTIONS constant + to include options that are always allowed """ DEFAULT_OPTIONS = ["sparkSession", "tableName", "rows", "partitions"] def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Callable: - bad_options = [keyword_arg for keyword_arg in kwargs - if keyword_arg not in DEFAULT_OPTIONS and keyword_arg not in options] + bad_options = [ + keyword_arg + for keyword_arg in kwargs + if keyword_arg not in DEFAULT_OPTIONS and keyword_arg not in options + ] if len(bad_options) > 0: errorMessage = f"""The following options are unsupported by provider: [{",".join(bad_options)}]""" @@ -256,7 +277,7 @@ def wrapper(*args, **kwargs) -> Callable: return decorator def checkOptions(self, options: dict[str, Any], allowedOptions: list[str]) -> DatasetDefinition: - """ Check that options are valid + """Check that options are valid :param options: options to check as dict :param allowedOptions: 
allowed options as list of strings @@ -268,7 +289,7 @@ def checkOptions(self, options: dict[str, Any], allowedOptions: list[str]) -> Da return self def autoComputePartitions(self, rows: int, columns: int) -> int: - """ Compute the number of partitions based on rows and columns + """Compute the number of partitions based on rows and columns :param rows: number of rows :param columns: number of columns @@ -285,29 +306,47 @@ def autoComputePartitions(self, rows: int, columns: int) -> int: return max(self.DEFAULT_PARTITIONS, int(math.log(rows / 350_000) * max(1, math.log(columns)))) class NoAssociatedDatasetsMixin(ABC): # noqa: B024 - """ Use this mixin to provide default implementation for data provider when it does not provide - any associated datasets + """Use this mixin to provide default implementation for data provider when it does not provide + any associated datasets """ - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int =-1, - **options: dict[str, Any]) -> DataGenerator: + + def getAssociatedDataset( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: raise NotImplementedError("Data provider does not produce any associated datasets!") class DatasetDecoratorUtils: - """ Defines the predefined_dataset decorator + """Defines the predefined_dataset decorator - :param cls: target class to apply decorator to - :param name: name of the dataset - :param tables: list of tables produced by the dataset provider, if None, defaults to [ DEFAULT_TABLE_NAME ] - :param primaryTable: primary table provided by dataset. Defaults to first table of table list - :param summary: Summary information for the dataset. If None, will be derived from target class name - :param description: Detailed description of the class. If None, will use the target class doc string - :param associatedDatasets: list of associated datasets produced by the dataset provider - :param supportsStreaming: Whether data set can be used in streaming scenarios + :param cls: target class to apply decorator to + :param name: name of the dataset + :param tables: list of tables produced by the dataset provider, if None, defaults to [ DEFAULT_TABLE_NAME ] + :param primaryTable: primary table provided by dataset. Defaults to first table of table list + :param summary: Summary information for the dataset. If None, will be derived from target class name + :param description: Detailed description of the class. If None, will use the target class doc string + :param associatedDatasets: list of associated datasets produced by the dataset provider + :param supportsStreaming: Whether data set can be used in streaming scenarios """ - def __init__(self, cls: type|None =None, *, name: str|None =None, tables: list[str]|None =None, primaryTable: str|None =None, summary: str|None =None, description: str|None =None, - associatedDatasets: list[str]|None =None, supportsStreaming: bool =False) -> None: + def __init__( + self, + cls: type | None = None, + *, + name: str | None = None, + tables: list[str] | None = None, + primaryTable: str | None = None, + summary: str | None = None, + description: str | None = None, + associatedDatasets: list[str] | None = None, + supportsStreaming: bool = False, + ) -> None: self._targetCls = cls # compute the data set provider name if not provided. 
@@ -338,8 +377,7 @@ def __init__(self, cls: type|None =None, *, name: str|None =None, tables: list[s generated_description = [ f"The datasetProvider '{cls.__name__}' provides a data spec for the '{self._datasetName}' dataset", "", # empty line - f"Summary: {self._summary}" - "", # empty line + f"Summary: {self._summary}" "", # empty line f"Tables generators provided: {', '.join(self._tables)}", "", # empty line f"Primary table: {self._primaryTable}", @@ -351,8 +389,8 @@ def __init__(self, cls: type|None =None, *, name: str|None =None, tables: list[s ] self._description = "\n".join(generated_description) - def mkClass(self, autoRegister: bool =False) -> type: - """ make the modified class for the Data Provider + def mkClass(self, autoRegister: bool = False) -> type: + """make the modified class for the Data Provider Applies the decorator args as a metadata object on the class. This is done at the class level as there is no instance of the target class at this point. @@ -362,15 +400,16 @@ def mkClass(self, autoRegister: bool =False) -> type: if self._targetCls is not None: # if self._targetCls is not None and (isinstance(self._targetCls, DatasetProvider) or # issubclass(self._targetCls, DatasetProvider)): - dataset_desc = DatasetProvider.DatasetDefinition(name=self._datasetName, - tables=self._tables, - primaryTable=self._primaryTable, - summary=self._summary, - description=self._description, - supportsStreaming=self._supportsStreaming, - providerClass=self._targetCls, - associatedDatasets=self._associatedDatasets - ) + dataset_desc = DatasetProvider.DatasetDefinition( + name=self._datasetName, + tables=self._tables, + primaryTable=self._primaryTable, + summary=self._summary, + description=self._description, + supportsStreaming=self._supportsStreaming, + providerClass=self._targetCls, + associatedDatasets=self._associatedDatasets, + ) self._targetCls._DATASET_DEFINITION = dataset_desc retval = self._targetCls else: @@ -382,8 +421,8 @@ def mkClass(self, autoRegister: bool =False) -> type: return retval -def dataset_definition(cls: type|None =None, *args: object, autoRegister: bool =False, **kwargs: object) -> type: - """ decorator to define standard dataset definition +def dataset_definition(cls: type | None = None, *args: object, autoRegister: bool = False, **kwargs: object) -> type: + """decorator to define standard dataset definition This is intended to be applied classes derived from DatasetProvider to simplify the implementation of the predefined datasets. @@ -415,8 +454,8 @@ class X(DatasetProvider) """ - def inner_wrapper(inner_cls: type|None =None, *inner_args: object, **inner_kwargs: object) -> type: - """ The inner wrapper function is used to handle the case where the decorator is used with arguments. + def inner_wrapper(inner_cls: type | None = None, *inner_args: object, **inner_kwargs: object) -> type: + """The inner wrapper function is used to handle the case where the decorator is used with arguments. It defers the application of the decorator to the target class until the target class is available. 
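To illustrate the decorator contract after this reformatting, a minimal custom provider might look like the following sketch; the dataset name "examples/simple_user", its columns, and its single option are hypothetical, modeled on the providers in this patch::

    from typing import Any, ClassVar

    import dbldatagen as dg
    from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


    @dataset_definition(name="examples/simple_user", summary="Illustrative user dataset", autoRegister=True)
    class SimpleUserProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider):
        COLUMN_COUNT = 2
        ALLOWED_OPTIONS: ClassVar[list[str]] = ["random"]

        @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
        def getTableGenerator(
            self,
            sparkSession,
            *,
            tableName: str | None = None,
            rows: int = -1,
            partitions: int = -1,
            **options: dict[str, Any],
        ) -> dg.DataGenerator:
            generateRandom = options.get("random", False)
            if rows is None or rows < 0:
                rows = DatasetProvider.DEFAULT_ROWS
            if partitions is None or partitions < 0:
                partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT)
            return (
                dg.DataGenerator(sparkSession=sparkSession, rows=rows, partitions=partitions)
                .withColumn("user_id", "long", minValue=1, maxValue=1_000_000, random=generateRandom)
                .withColumn("name", "string", template=r"\w \w", random=generateRandom)
            )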
:param inner_cls: inner class object @@ -426,8 +465,9 @@ def inner_wrapper(inner_cls: type|None =None, *inner_args: object, **inner_kwarg :return: Returns the target class object """ try: - assert DatasetProvider.isValidDataProviderType(inner_cls), \ - f"Target class of decorator ({inner_cls}) must inherit from DataProvider" + assert DatasetProvider.isValidDataProviderType( + inner_cls + ), f"Target class of decorator ({inner_cls}) must inherit from DataProvider" return DatasetProvider.DatasetDecoratorUtils(inner_cls, *args, **kwargs).mkClass(autoRegister) except Exception as exc: raise TypeError(f"Invalid decorator usage: {exc}") from exc @@ -438,8 +478,9 @@ def inner_wrapper(inner_cls: type|None =None, *inner_args: object, **inner_kwarg if cls is not None: # handle decorator syntax with no arguments # when no arguments are provided to the decorator, the only argument passed is an implicit class object - assert DatasetProvider.isValidDataProviderType(cls), \ - f"Target class of decorator ({cls}) must inherit from DataProvider" + assert DatasetProvider.isValidDataProviderType( + cls + ), f"Target class of decorator ({cls}) must inherit from DataProvider" return DatasetProvider.DatasetDecoratorUtils(cls, *args, **kwargs).mkClass(autoRegister) else: # handle decorator syntax with arguments - here we simply return the inner wrapper function diff --git a/dbldatagen/datasets/multi_table_sales_order_provider.py b/dbldatagen/datasets/multi_table_sales_order_provider.py index f6b0273d..4ce889f6 100644 --- a/dbldatagen/datasets/multi_table_sales_order_provider.py +++ b/dbldatagen/datasets/multi_table_sales_order_provider.py @@ -8,13 +8,24 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="multi_table/sales_order", summary="Multi-table sales order dataset", supportsStreaming=True, - autoRegister=True, - tables=["customers", "carriers", "catalog_items", "base_orders", "base_order_line_items", - "base_order_shipments", "base_invoices"], - associatedDatasets=["orders", "order_line_items", "order_shipments", "invoices"]) +@dataset_definition( + name="multi_table/sales_order", + summary="Multi-table sales order dataset", + supportsStreaming=True, + autoRegister=True, + tables=[ + "customers", + "carriers", + "catalog_items", + "base_orders", + "base_order_line_items", + "base_order_shipments", + "base_invoices", + ], + associatedDatasets=["orders", "order_line_items", "order_shipments", "invoices"], +) class MultiTableSalesOrderProvider(DatasetProvider): - """ Generates a multi-table sales order scenario + """Generates a multi-table sales order scenario See [https://databrickslabs.github.io/dbldatagen/public_docs/multi_table_data.html] @@ -51,6 +62,7 @@ class MultiTableSalesOrderProvider(DatasetProvider): While it is possible to specify the number of rows explicitly when getting each table generator, the default will be to compute the number of rows from these options. 
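A sketch of building just the customers table from the reformatted multi-table provider; the row and customer counts are illustrative::

    from dbldatagen.datasets import MultiTableSalesOrderProvider

    # Assumes an active SparkSession named `spark`.
    provider = MultiTableSalesOrderProvider()
    customers_spec = provider.getCustomers(
        spark,
        rows=1_000,
        partitions=4,
        numCustomers=1_000,
        dummyValues=0,
    )
    customers_df = customers_spec.build()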
""" + MAX_LONG = 9223372036854775807 DEFAULT_NUM_CUSTOMERS = 1_000 DEFAULT_NUM_CARRIERS = 100 @@ -67,7 +79,9 @@ class MultiTableSalesOrderProvider(DatasetProvider): SHIPMENT_MIN_VALUE = 10_000_000 INVOICE_MIN_VALUE = 1_000_000 - def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCustomers: int, dummyValues: int) -> DataGenerator: + def getCustomers( + self, sparkSession: SparkSession, *, rows: int, partitions: int, numCustomers: int, dummyValues: int + ) -> DataGenerator: # Validate the options: if numCustomers is None or numCustomers < 0: numCustomers = self.DEFAULT_NUM_CUSTOMERS @@ -85,28 +99,39 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int .withColumn("num_employees", "integer", minValue=1, maxValue=10_000, random=True) .withColumn("region", "string", values=["AMER", "EMEA", "APAC", "NONE"], random=True) .withColumn("phone_number", "string", template="ddd-ddd-dddd") - .withColumn("email_user_name", "string", - values=["billing", "procurement", "office", "purchasing", "buyer"], omit=True) + .withColumn( + "email_user_name", + "string", + values=["billing", "procurement", "office", "purchasing", "buyer"], + omit=True, + ) .withColumn("email_address", "string", expr="concat(email_user_name, '@', lower(customer_name), '.com')") .withColumn("payment_terms", "string", values=["DUE_ON_RECEIPT", "NET30", "NET60", "NET120"]) .withColumn("created_on", "date", begin="2000-01-01", end=self.DEFAULT_START_DATE, interval="1 DAY") .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) - .withColumn("updated_on", "date", expr="""case when is_updated then created_on - else date_add(created_on, updated_after_days) end""") + .withColumn( + "updated_on", + "date", + expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""", + ) .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") ) # Add dummy values if they were requested: if dummyValues > 0: - customers_data_spec = customers_data_spec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + customers_data_spec = customers_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return customers_data_spec - def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCarriers: int, dummyValues: int) -> DataGenerator: + def getCarriers( + self, sparkSession: SparkSession, *, rows: int, partitions: int, numCarriers: int, dummyValues: int + ) -> DataGenerator: # Validate the options: if numCarriers is None or numCarriers < 0: numCarriers = self.DEFAULT_NUM_CARRIERS @@ -121,27 +146,33 @@ def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, .withColumn("carrier_id", "integer", minValue=self.CARRIER_MIN_VALUE, uniqueValues=numCarriers) .withColumn("carrier_name", "string", prefix="CARRIER", baseColumn="carrier_id") .withColumn("phone_number", "string", template="ddd-ddd-dddd") - .withColumn("email_user_name", "string", - values=["shipping", "parcel", "logistics", "carrier"], omit=True) + .withColumn("email_user_name", "string", values=["shipping", 
"parcel", "logistics", "carrier"], omit=True) .withColumn("email_address", "string", expr="concat(email_user_name, '@', lower(carrier_name), '.com')") .withColumn("created_on", "date", begin="2000-01-01", end=self.DEFAULT_START_DATE, interval="1 DAY") .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) - .withColumn("updated_on", "date", expr="""case when is_updated then created_on - else date_add(created_on, updated_after_days) end""") + .withColumn( + "updated_on", + "date", + expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""", + ) .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") ) # Add dummy values if they were requested: if dummyValues > 0: - carriers_data_spec = carriers_data_spec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + carriers_data_spec = carriers_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return carriers_data_spec - def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCatalogItems: int, dummyValues: int) -> DataGenerator: + def getCatalogItems( + self, sparkSession: SparkSession, *, rows: int, partitions: int, numCatalogItems: int, dummyValues: int + ) -> DataGenerator: if numCatalogItems is None or numCatalogItems < 0: numCatalogItems = self.DEFAULT_NUM_CATALOG_ITEMS if rows is None or rows < 0: @@ -152,8 +183,9 @@ def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: # Create the base data generation spec: catalog_items_data_spec = ( dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) - .withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, - uniqueValues=numCatalogItems) + .withColumn( + "catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, uniqueValues=numCatalogItems + ) .withColumn("item_name", "string", prefix="ITEM", baseColumn="catalog_item_id") .withColumn("unit_price", "decimal(8,2)", minValue=1.50, maxValue=500.0, random=True) .withColumn("discount_rate", "decimal(3,2)", minValue=0.00, maxValue=9.99, random=True) @@ -164,22 +196,36 @@ def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: .withColumn("created_by", "integer", minValue=1_000, maxValue=9_999, random=True) .withColumn("is_updated", "boolean", expr="rand() > 0.75", omit=True) .withColumn("updated_after_days", "integer", minValue=0, maxValue=1_000, random=True, omit=True) - .withColumn("updated_on", "date", expr="""case when is_updated then created_on - else date_add(created_on, updated_after_days) end""") + .withColumn( + "updated_on", + "date", + expr="""case when is_updated then created_on + else date_add(created_on, updated_after_days) end""", + ) .withColumn("updated_by_user", "integer", minValue=1_000, maxValue=9_999, random=True, omit=True) .withColumn("updated_by", "integer", expr="case when is_updated then updated_by_user else created_by end") ) # Add dummy values if they were requested: if dummyValues > 0: - catalog_items_data_spec = ( - catalog_items_data_spec.withColumn("dummy", "long", random=True, - numColumns=dummyValues, 
minValue=1, maxValue=self.MAX_LONG)) + catalog_items_data_spec = catalog_items_data_spec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return catalog_items_data_spec - def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCustomers: int, startDate: str, - endDate: str, dummyValues: int) -> DataGenerator: + def getBaseOrders( + self, + sparkSession: SparkSession, + *, + rows: int, + partitions: int, + numOrders: int, + numCustomers: int, + startDate: str, + endDate: str, + dummyValues: int, + ) -> DataGenerator: # Validate the options: if numOrders is None or numOrders < 0: numOrders = self.DEFAULT_NUM_ORDERS @@ -199,14 +245,22 @@ def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: in dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, uniqueValues=numOrders) .withColumn("order_title", "string", prefix="ORDER", baseColumn="order_id") - .withColumn("customer_id", "integer", minValue=self.CUSTOMER_MIN_VALUE, - maxValue=self.CUSTOMER_MIN_VALUE + numCustomers, random=True) + .withColumn( + "customer_id", + "integer", + minValue=self.CUSTOMER_MIN_VALUE, + maxValue=self.CUSTOMER_MIN_VALUE + numCustomers, + random=True, + ) .withColumn("purchase_order_number", "string", template="KKKK-KKKK-DDDD-KKKK", random=True) - .withColumn("order_open_date", "date", begin=startDate, end=endDate, - interval="1 DAY", random=True) + .withColumn("order_open_date", "date", begin=startDate, end=endDate, interval="1 DAY", random=True) .withColumn("order_open_to_close_days", "integer", minValue=0, maxValue=30, random=True, omit=True) - .withColumn("order_close_date", "date", expr=f"""least(cast('{endDate}' as date), - date_add(order_open_date, order_open_to_close_days))""") + .withColumn( + "order_close_date", + "date", + expr=f"""least(cast('{endDate}' as date), + date_add(order_open_date, order_open_to_close_days))""", + ) .withColumn("sales_rep_id", "integer", minValue=1_000, maxValue=9_999, random=True) .withColumn("sales_group_id", "integer", minValue=100, maxValue=999, random=True) .withColumn("created_on", "date", expr="order_open_date") @@ -219,12 +273,22 @@ def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: in # Add dummy values if they were requested: if dummyValues > 0: base_orders_data_spec = base_orders_data_spec.withColumn( - "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return base_orders_data_spec - def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCatalogItems: int, - lineItemsPerOrder: int, dummyValues: int) -> DataGenerator: + def getBaseOrderLineItems( + self, + sparkSession: SparkSession, + *, + rows: int, + partitions: int, + numOrders: int, + numCatalogItems: int, + lineItemsPerOrder: int, + dummyValues: int, + ) -> DataGenerator: if numOrders is None or numOrders < 0: numOrders = self.DEFAULT_NUM_ORDERS if numCatalogItems is None or numCatalogItems < 0: @@ -239,13 +303,28 @@ def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partit # Create the base data generation spec: base_order_line_items_data_spec = ( dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) - .withColumn("order_line_item_id", "integer", 
minValue=self.ORDER_LINE_ITEM_MIN_VALUE, - uniqueValues=numOrders * lineItemsPerOrder) - .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, - uniqueValues=numOrders, random=True) - .withColumn("catalog_item_id", "integer", minValue=self.CATALOG_ITEM_MIN_VALUE, - maxValue=self.CATALOG_ITEM_MIN_VALUE + numCatalogItems, uniqueValues=numCatalogItems, - random=True) + .withColumn( + "order_line_item_id", + "integer", + minValue=self.ORDER_LINE_ITEM_MIN_VALUE, + uniqueValues=numOrders * lineItemsPerOrder, + ) + .withColumn( + "order_id", + "integer", + minValue=self.ORDER_MIN_VALUE, + maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, + random=True, + ) + .withColumn( + "catalog_item_id", + "integer", + minValue=self.CATALOG_ITEM_MIN_VALUE, + maxValue=self.CATALOG_ITEM_MIN_VALUE + numCatalogItems, + uniqueValues=numCatalogItems, + random=True, + ) .withColumn("has_discount", "boolean", expr="rand() > 0.9") .withColumn("units", "integer", minValue=1, maxValue=100, random=True) .withColumn("added_after_order_creation_days", "integer", minValue=0, maxValue=30, random=True) @@ -254,11 +333,21 @@ def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partit # Add dummy values if they were requested: if dummyValues > 0: base_order_line_items_data_spec = base_order_line_items_data_spec.withColumn( - "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return base_order_line_items_data_spec - def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCarriers: int, dummyValues: int) -> DataGenerator: + def getBaseOrderShipments( + self, + sparkSession: SparkSession, + *, + rows: int, + partitions: int, + numOrders: int, + numCarriers: int, + dummyValues: int, + ) -> DataGenerator: # Validate the options: if numOrders is None or numOrders < 0: numOrders = self.DEFAULT_NUM_ORDERS @@ -273,16 +362,33 @@ def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partit base_order_shipments_data_spec = ( dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) .withColumn("order_shipment_id", "integer", minValue=self.ORDER_MIN_VALUE, uniqueValues=numOrders) - .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, - uniqueValues=numOrders, random=True) - .withColumn("carrier_id", "integer", minValue=self.CARRIER_MIN_VALUE, - maxValue=self.CARRIER_MIN_VALUE + numCarriers, uniqueValues=numCarriers, random=True) + .withColumn( + "order_id", + "integer", + minValue=self.ORDER_MIN_VALUE, + maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, + random=True, + ) + .withColumn( + "carrier_id", + "integer", + minValue=self.CARRIER_MIN_VALUE, + maxValue=self.CARRIER_MIN_VALUE + numCarriers, + uniqueValues=numCarriers, + random=True, + ) .withColumn("house_number", "integer", minValue=1, maxValue=9999, random=True, omit=True) .withColumn("street_number", "integer", minValue=1, maxValue=150, random=True, omit=True) - .withColumn("street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], - random=True) - .withColumn("ship_to_address_line", "string", expr="""concat_ws(' ', house_number, street_direction, - street_number, 'ST')""") + .withColumn( + "street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], 
random=True + ) + .withColumn( + "ship_to_address_line", + "string", + expr="""concat_ws(' ', house_number, street_direction, + street_number, 'ST')""", + ) .withColumn("ship_to_country_code", "string", values=["US", "CA"], weights=[8, 3], random=True) .withColumn("order_open_to_ship_days", "integer", minValue=0, maxValue=30, random=True) .withColumn("estimated_transit_days", "integer", minValue=1, maxValue=5, random=True) @@ -294,11 +400,14 @@ def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partit # Add dummy values if they were requested: if dummyValues > 0: base_order_shipments_data_spec = base_order_shipments_data_spec.withColumn( - "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return base_order_shipments_data_spec - def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, dummyValues: int) -> DataGenerator: + def getBaseInvoices( + self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, dummyValues: int + ) -> DataGenerator: # Validate the options: if numOrders is None or numOrders < 0: numOrders = self.DEFAULT_NUM_ORDERS @@ -311,14 +420,25 @@ def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: base_invoices_data_spec = ( dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) .withColumn("invoice_id", "integer", minValue=self.INVOICE_MIN_VALUE, uniqueValues=numOrders) - .withColumn("order_id", "integer", minValue=self.ORDER_MIN_VALUE, maxValue=self.ORDER_MIN_VALUE + numOrders, - uniqueValues=numOrders, random=True) + .withColumn( + "order_id", + "integer", + minValue=self.ORDER_MIN_VALUE, + maxValue=self.ORDER_MIN_VALUE + numOrders, + uniqueValues=numOrders, + random=True, + ) .withColumn("house_number", "integer", minValue=1, maxValue=9999, random=True, omit=True) .withColumn("street_number", "integer", minValue=1, maxValue=150, random=True, omit=True) - .withColumn("street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], - random=True) - .withColumn("bill_to_address_line", "string", expr="""concat_ws(' ', house_number, street_direction, - street_number, 'ST')""") + .withColumn( + "street_direction", "string", values=["", "N", "S", "E", "W", "NW", "NE", "SW", "SE"], random=True + ) + .withColumn( + "bill_to_address_line", + "string", + expr="""concat_ws(' ', house_number, street_direction, + street_number, 'ST')""", + ) .withColumn("bill_to_country_code", "string", values=["US", "CA"], weights=[8, 3], random=True) .withColumn("order_close_to_invoice_days", "integer", minValue=0, maxValue=5, random=True) .withColumn("order_close_to_create_days", "integer", minValue=0, maxValue=2, random=True) @@ -332,13 +452,32 @@ def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: # Add dummy values if they were requested: if dummyValues > 0: base_invoices_data_spec = base_invoices_data_spec.withColumn( - "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG) + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return base_invoices_data_spec - @DatasetProvider.allowed_options(options=["numCustomers", "numCarriers", "numCatalogItems", "numOrders", - "lineItemsPerOrder", "startDate", "endDate", "dummyValues"]) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: 
int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + @DatasetProvider.allowed_options( + options=[ + "numCustomers", + "numCarriers", + "numCatalogItems", + "numOrders", + "lineItemsPerOrder", + "startDate", + "endDate", + "dummyValues", + ] + ) + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: # Get the option values: numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS) numCarriers = options.get("numCarriers", self.DEFAULT_NUM_CARRIERS) @@ -353,27 +492,15 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N spec = None if tableName == "customers": spec = self.getCustomers( - sparkSession, - rows=rows, - partitions=partitions, - numCustomers=numCustomers, - dummyValues=dummyValues + sparkSession, rows=rows, partitions=partitions, numCustomers=numCustomers, dummyValues=dummyValues ) elif tableName == "carriers": spec = self.getCarriers( - sparkSession, - rows=rows, - partitions=partitions, - numCarriers=numCarriers, - dummyValues=dummyValues + sparkSession, rows=rows, partitions=partitions, numCarriers=numCarriers, dummyValues=dummyValues ) elif tableName == "catalog_items": spec = self.getCatalogItems( - sparkSession, - rows=rows, - partitions=partitions, - numCatalogItems=numCatalogItems, - dummyValues=dummyValues + sparkSession, rows=rows, partitions=partitions, numCatalogItems=numCatalogItems, dummyValues=dummyValues ) elif tableName == "base_orders": spec = self.getBaseOrders( @@ -384,7 +511,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N numCustomers=numCustomers, startDate=startDate, endDate=endDate, - dummyValues=dummyValues + dummyValues=dummyValues, ) elif tableName == "base_order_line_items": spec = self.getBaseOrderLineItems( @@ -394,7 +521,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N numOrders=numOrders, numCatalogItems=numCatalogItems, lineItemsPerOrder=lineItemsPerOrder, - dummyValues=dummyValues + dummyValues=dummyValues, ) elif tableName == "base_order_shipments": spec = self.getBaseOrderShipments( @@ -403,72 +530,88 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N partitions=partitions, numOrders=numOrders, numCarriers=numCarriers, - dummyValues=dummyValues + dummyValues=dummyValues, ) elif tableName == "base_invoices": spec = self.getBaseInvoices( - sparkSession, - rows=rows, - partitions=partitions, - numOrders=numOrders, - dummyValues=dummyValues + sparkSession, rows=rows, partitions=partitions, numOrders=numOrders, dummyValues=dummyValues ) if spec is not None: return spec - raise ValueError("tableName must be 'customers', 'carriers', 'catalog_items', 'base_orders'," - "'base_order_line_items', 'base_order_shipments', 'base_invoices'") + raise ValueError( + "tableName must be 'customers', 'carriers', 'catalog_items', 'base_orders'," + "'base_order_line_items', 'base_order_shipments', 'base_invoices'" + ) - @DatasetProvider.allowed_options(options=[ - "customers", - "carriers", - "catalogItems", - "baseOrders", - "baseOrderLineItems", - "baseOrderShipments", - "baseInvoices" - ]) - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + @DatasetProvider.allowed_options( + options=[ + "customers", + "carriers", + "catalogItems", + 
"baseOrders", + "baseOrderLineItems", + "baseOrderShipments", + "baseInvoices", + ] + ) + def getAssociatedDataset( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: dfCustomers = options.get("customers") - assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \ - "Option `customers` should be a dataframe of customer records" + assert dfCustomers is not None and issubclass( + type(dfCustomers), DataFrame + ), "Option `customers` should be a dataframe of customer records" dfCarriers = options.get("carriers") - assert dfCarriers is not None and issubclass(type(dfCarriers), DataFrame), \ - "Option `carriers` should be dataframe of carrier records" + assert dfCarriers is not None and issubclass( + type(dfCarriers), DataFrame + ), "Option `carriers` should be dataframe of carrier records" dfCatalogItems = options.get("catalogItems") - assert dfCatalogItems is not None and issubclass(type(dfCatalogItems), DataFrame), \ - "Option `catalogItems` should be dataframe of catalog item records" + assert dfCatalogItems is not None and issubclass( + type(dfCatalogItems), DataFrame + ), "Option `catalogItems` should be dataframe of catalog item records" dfBaseOrders = options.get("baseOrders") - assert dfBaseOrders is not None and issubclass(type(dfBaseOrders), DataFrame), \ - "Option `baseOrders` should be dataframe of base order records" + assert dfBaseOrders is not None and issubclass( + type(dfBaseOrders), DataFrame + ), "Option `baseOrders` should be dataframe of base order records" dfBaseOrderLineItems = options.get("baseOrderLineItems") - assert dfBaseOrderLineItems is not None and issubclass(type(dfBaseOrderLineItems), DataFrame), \ - "Option `baseOrderLineItems` should be dataframe of base order line item records" + assert dfBaseOrderLineItems is not None and issubclass( + type(dfBaseOrderLineItems), DataFrame + ), "Option `baseOrderLineItems` should be dataframe of base order line item records" dfBaseOrderShipments = options.get("baseOrderShipments") - assert dfBaseOrderShipments is not None and issubclass(type(dfBaseOrderShipments), DataFrame), \ - "Option `baseOrderLineItems` should be dataframe of base order shipment records" + assert dfBaseOrderShipments is not None and issubclass( + type(dfBaseOrderShipments), DataFrame + ), "Option `baseOrderLineItems` should be dataframe of base order shipment records" dfBaseInvoices = options.get("baseInvoices") - assert dfBaseInvoices is not None and issubclass(type(dfBaseInvoices), DataFrame), \ - "Option `baseInvoices` should be dataframe of base invoice records" + assert dfBaseInvoices is not None and issubclass( + type(dfBaseInvoices), DataFrame + ), "Option `baseInvoices` should be dataframe of base invoice records" if tableName == "orders": dfOrderTotals = ( dfBaseOrderLineItems.alias("a") .join(dfCatalogItems.alias("b"), on="catalog_item_id") - .selectExpr("a.order_id as order_id", - "a.order_line_item_id as order_line_item_id", - """case when a.has_discount then (b.unit_price * 1 - (b.discount_rate / 100)) + .selectExpr( + "a.order_id as order_id", + "a.order_line_item_id as order_line_item_id", + """case when a.has_discount then (b.unit_price * 1 - (b.discount_rate / 100)) else b.unit_price end as unit_price""", - "a.units as units") + "a.units as units", + ) .selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price") .groupBy("order_id") - 
.agg(F.count("order_line_item_id").alias("num_line_items"), - F.sum("total_price").alias("order_total")) + .agg(F.count("order_line_item_id").alias("num_line_items"), F.sum("total_price").alias("order_total")) ) return ( dfBaseOrders.alias("a") @@ -488,7 +631,8 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non "a.created_on", "a.created_by", "a.updated_on", - "a.updated_by") + "a.updated_by", + ) ) if tableName == "order_line_items": @@ -506,7 +650,8 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non """case when a.has_discount then a.units * c.unit_price * (1 - (c.discount_rate / 100)) else a.units * c.unit_price end as net_price""", "date_add(b.created_on, a.added_after_order_creation_days) as created_on", - "b.created_by") + "b.created_by", + ) ) if tableName == "order_shipments": @@ -525,18 +670,21 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non "a.estimated_transit_days", "a.actual_transit_days", "b.created_on", - "b.created_by") + "b.created_by", + ) ) if tableName == "invoices": dfOrderTotals = ( dfBaseOrderLineItems.alias("a") .join(dfCatalogItems.alias("b"), on="catalog_item_id") - .selectExpr("a.order_id as order_id", - "a.order_line_item_id as order_line_item_id", - """case when a.has_discount then (b.unit_price * 1 - (b.discount_rate / 100)) + .selectExpr( + "a.order_id as order_id", + "a.order_line_item_id as order_line_item_id", + """case when a.has_discount then (b.unit_price * 1 - (b.discount_rate / 100)) else b.unit_price end as unit_price""", - "a.units as units") + "a.units as units", + ) .selectExpr("order_id", "order_line_item_id", "unit_price * units as total_price") .groupBy("order_id") .agg(F.count("order_line_item_id").alias("num_line_items"), F.sum("total_price").alias("order_total")) @@ -561,5 +709,6 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non """case when a.is_updated then date_add(b.order_close_date, a.order_close_to_create_days + a.updated_after_days) else date_add(b.order_close_date, a.order_close_to_create_days) end as updated_on""", - "case when a.is_updated then a.updated_by else a.created_by end as updated_by") + "case when a.is_updated then a.updated_by else a.created_by end as updated_by", + ) ) diff --git a/dbldatagen/datasets/multi_table_telephony_provider.py b/dbldatagen/datasets/multi_table_telephony_provider.py index 646d2d95..9bcbbc79 100644 --- a/dbldatagen/datasets/multi_table_telephony_provider.py +++ b/dbldatagen/datasets/multi_table_telephony_provider.py @@ -8,12 +8,16 @@ from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition -@dataset_definition(name="multi_table/telephony", summary="Multi-table telephony dataset", supportsStreaming=True, - autoRegister=True, - tables=["plans", "customers", "deviceEvents"], - associatedDatasets=["invoices"]) +@dataset_definition( + name="multi_table/telephony", + summary="Multi-table telephony dataset", + supportsStreaming=True, + autoRegister=True, + tables=["plans", "customers", "deviceEvents"], + associatedDatasets=["invoices"], +) class MultiTableTelephonyProvider(DatasetProvider): - """ Telephony multi-table example from documentation + """Telephony multi-table example from documentation See [https://databrickslabs.github.io/dbldatagen/public_docs/multi_table_data.html] @@ -52,6 +56,7 @@ class MultiTableTelephonyProvider(DatasetProvider): be to compute the number of rows from these. 
""" + MAX_LONG = 9223372036854775807 PLAN_MIN_VALUE = 100 DEFAULT_NUM_PLANS = 20 @@ -62,7 +67,16 @@ class MultiTableTelephonyProvider(DatasetProvider): DEFAULT_NUM_DAYS = 31 DEFAULT_AVG_EVENTS_PER_CUSTOMER = 50 - def getPlans(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numPlans: int, dummyValues: int) -> DataGenerator: + def getPlans( + self, + sparkSession: SparkSession, + *, + rows: int, + partitions: int, + generateRandom: bool, + numPlans: int, + dummyValues: int, + ) -> DataGenerator: if numPlans is None or numPlans < 0: numPlans = self.DEFAULT_NUM_PLANS @@ -78,36 +92,70 @@ def getPlans(self, sparkSession: SparkSession, *, rows: int, partitions: int, ge .withColumn("plan_id", "int", minValue=self.PLAN_MIN_VALUE, uniqueValues=numPlans) # use plan_id as root value .withColumn("plan_name", prefix="plan", baseColumn="plan_id") - # note default step is 1 so you must specify a step for small number ranges, - .withColumn("cost_per_mb", "decimal(5,3)", minValue=0.005, maxValue=0.050, - step=0.005, random=generateRandom) - .withColumn("cost_per_message", "decimal(5,3)", minValue=0.001, maxValue=0.02, - step=0.001, random=generateRandom) - .withColumn("cost_per_minute", "decimal(5,3)", minValue=0.001, maxValue=0.01, - step=0.001, random=generateRandom) - + .withColumn( + "cost_per_mb", "decimal(5,3)", minValue=0.005, maxValue=0.050, step=0.005, random=generateRandom + ) + .withColumn( + "cost_per_message", "decimal(5,3)", minValue=0.001, maxValue=0.02, step=0.001, random=generateRandom + ) + .withColumn( + "cost_per_minute", "decimal(5,3)", minValue=0.001, maxValue=0.01, step=0.001, random=generateRandom + ) # we're modelling long distance and international prices simplistically - # each is a multiplier thats applied to base rate - .withColumn("ld_multiplier", "decimal(5,3)", minValue=1.5, maxValue=3, step=0.05, - random=generateRandom, distribution="normal", omit=True) - .withColumn("ld_cost_per_minute", "decimal(5,3)", - expr="cost_per_minute * ld_multiplier", - baseColumns=["cost_per_minute", "ld_multiplier"]) - .withColumn("intl_multiplier", "decimal(5,3)", minValue=2, maxValue=4, step=0.05, - random=generateRandom, distribution="normal", omit=True) - .withColumn("intl_cost_per_minute", "decimal(5,3)", - expr="cost_per_minute * intl_multiplier", - baseColumns=["cost_per_minute", "intl_multiplier"]) + .withColumn( + "ld_multiplier", + "decimal(5,3)", + minValue=1.5, + maxValue=3, + step=0.05, + random=generateRandom, + distribution="normal", + omit=True, + ) + .withColumn( + "ld_cost_per_minute", + "decimal(5,3)", + expr="cost_per_minute * ld_multiplier", + baseColumns=["cost_per_minute", "ld_multiplier"], + ) + .withColumn( + "intl_multiplier", + "decimal(5,3)", + minValue=2, + maxValue=4, + step=0.05, + random=generateRandom, + distribution="normal", + omit=True, + ) + .withColumn( + "intl_cost_per_minute", + "decimal(5,3)", + expr="cost_per_minute * intl_multiplier", + baseColumns=["cost_per_minute", "intl_multiplier"], + ) ) if dummyValues > 0: - plan_dataspec = plan_dataspec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + plan_dataspec = plan_dataspec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return plan_dataspec - def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numCustomers: int, numPlans: int, dummyValues: int) -> DataGenerator: + def getCustomers( + self, + 
sparkSession: SparkSession, + *, + rows: int, + partitions: int, + generateRandom: bool, + numCustomers: int, + numPlans: int, + dummyValues: int, + ) -> DataGenerator: if numCustomers is None or numCustomers < 0: numCustomers = self.DEFAULT_NUM_CUSTOMERS @@ -117,37 +165,52 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int if partitions is None or partitions < 0: partitions = self.autoComputePartitions(rows, 6 + dummyValues) - customer_dataspec = (dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) - .withColumn("customer_id", "decimal(10)", minValue=self.CUSTOMER_MIN_VALUE, - uniqueValues=numCustomers) - .withColumn("customer_name", template=r"\\w \\w|\\w a. \\w") - - # use the following for a simple sequence - # .withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, - # uniqueValues=UNIQUE_CUSTOMERS) - - .withColumn("device_id", "decimal(10)", minValue=self.DEVICE_MIN_VALUE, - baseColumn="customer_id", baseColumnType="hash") - - .withColumn("phone_number", "decimal(10)", minValue=self.SUBSCRIBER_NUM_MIN_VALUE, - baseColumn=["customer_id", "customer_name"], baseColumnType="hash") - - # for email, we'll just use the formatted phone number - .withColumn("email", "string", format="subscriber_%s@myoperator.com", - baseColumn="phone_number") - .withColumn("plan", "int", minValue=self.PLAN_MIN_VALUE, uniqueValues=numPlans, - random=generateRandom) - .withConstraint(dg.constraints.UniqueCombinations(columns=["device_id"])) - .withConstraint(dg.constraints.UniqueCombinations(columns=["phone_number"])) - ) + customer_dataspec = ( + dg.DataGenerator(sparkSession, rows=rows, partitions=partitions) + .withColumn("customer_id", "decimal(10)", minValue=self.CUSTOMER_MIN_VALUE, uniqueValues=numCustomers) + .withColumn("customer_name", template=r"\\w \\w|\\w a. 
\\w") + # use the following for a simple sequence + # .withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, + # uniqueValues=UNIQUE_CUSTOMERS) + .withColumn( + "device_id", + "decimal(10)", + minValue=self.DEVICE_MIN_VALUE, + baseColumn="customer_id", + baseColumnType="hash", + ) + .withColumn( + "phone_number", + "decimal(10)", + minValue=self.SUBSCRIBER_NUM_MIN_VALUE, + baseColumn=["customer_id", "customer_name"], + baseColumnType="hash", + ) + # for email, we'll just use the formatted phone number + .withColumn("email", "string", format="subscriber_%s@myoperator.com", baseColumn="phone_number") + .withColumn("plan", "int", minValue=self.PLAN_MIN_VALUE, uniqueValues=numPlans, random=generateRandom) + .withConstraint(dg.constraints.UniqueCombinations(columns=["device_id"])) + .withConstraint(dg.constraints.UniqueCombinations(columns=["phone_number"])) + ) if dummyValues > 0: - customer_dataspec = customer_dataspec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + customer_dataspec = customer_dataspec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return customer_dataspec - def getDeviceEvents(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numCustomers: int, numDays: int, dummyValues: int, - averageEventsPerCustomer: int) -> DataGenerator: + def getDeviceEvents( + self, + sparkSession: SparkSession, + *, + rows: int, + partitions: int, + generateRandom: bool, + numCustomers: int, + numDays: int, + dummyValues: int, + averageEventsPerCustomer: int, + ) -> DataGenerator: MB_100 = 100 * 1000 * 1000 K_1 = 1000 @@ -158,67 +221,105 @@ def getDeviceEvents(self, sparkSession: SparkSession, *, rows: int, partitions: partitions = self.autoComputePartitions(rows, 8 + dummyValues) # use random seed method of 'hash_fieldname' for better spread - default in later builds - events_dataspec = (dg.DataGenerator(sparkSession, rows=rows, partitions=partitions, - randomSeed=42, randomSeedMethod="hash_fieldname") - # use same logic as per customers dataset to ensure matching keys - # but make them random - .withColumn("device_id_base", "decimal(10)", minValue=self.CUSTOMER_MIN_VALUE, - uniqueValues=numCustomers, - random=generateRandom, omit=True) - .withColumn("device_id", "decimal(10)", minValue=self.DEVICE_MIN_VALUE, - baseColumn="device_id_base", baseColumnType="hash") - - # use specific random seed to get better spread of values - .withColumn("event_type", "string", - values=["sms", "internet", "local call", "ld call", "intl call"], - weights=[50, 50, 20, 10, 5], random=generateRandom) - - # use Gamma distribution for skew towards short calls - .withColumn("base_minutes", "decimal(7,2)", - minValue=1.0, maxValue=100.0, step=0.1, - distribution=dg.distributions.Gamma(shape=1.5, scale=2.0), - random=generateRandom, omit=True) - - # use Gamma distribution for skew towards short transfers - .withColumn("base_bytes_transferred", "decimal(12)", - minValue=K_1, maxValue=MB_100, - distribution=dg.distributions.Gamma(shape=0.75, scale=2.0), - random=generateRandom, omit=True) - - .withColumn("minutes", "decimal(7,2)", - baseColumn=["event_type", "base_minutes"], - expr=""" + events_dataspec = ( + dg.DataGenerator( + sparkSession, rows=rows, partitions=partitions, randomSeed=42, randomSeedMethod="hash_fieldname" + ) + # use same logic as per customers dataset to ensure matching keys + # but make them random + .withColumn( + "device_id_base", + 
"decimal(10)", + minValue=self.CUSTOMER_MIN_VALUE, + uniqueValues=numCustomers, + random=generateRandom, + omit=True, + ) + .withColumn( + "device_id", + "decimal(10)", + minValue=self.DEVICE_MIN_VALUE, + baseColumn="device_id_base", + baseColumnType="hash", + ) + # use specific random seed to get better spread of values + .withColumn( + "event_type", + "string", + values=["sms", "internet", "local call", "ld call", "intl call"], + weights=[50, 50, 20, 10, 5], + random=generateRandom, + ) + # use Gamma distribution for skew towards short calls + .withColumn( + "base_minutes", + "decimal(7,2)", + minValue=1.0, + maxValue=100.0, + step=0.1, + distribution=dg.distributions.Gamma(shape=1.5, scale=2.0), + random=generateRandom, + omit=True, + ) + # use Gamma distribution for skew towards short transfers + .withColumn( + "base_bytes_transferred", + "decimal(12)", + minValue=K_1, + maxValue=MB_100, + distribution=dg.distributions.Gamma(shape=0.75, scale=2.0), + random=generateRandom, + omit=True, + ) + .withColumn( + "minutes", + "decimal(7,2)", + baseColumn=["event_type", "base_minutes"], + expr=""" case when event_type in ("local call", "ld call", "intl call") then base_minutes else 0 end - """) - .withColumn("bytes_transferred", "decimal(12)", - baseColumn=["event_type", "base_bytes_transferred"], - expr=""" + """, + ) + .withColumn( + "bytes_transferred", + "decimal(12)", + baseColumn=["event_type", "base_bytes_transferred"], + expr=""" case when event_type = "internet" then base_bytes_transferred else 0 end - """) - - .withColumn("event_ts", "timestamp", - data_range=dg.DateRange("2020-07-01 00:00:00", - "2020-07-31 11:59:59", - "seconds=1"), - random=True) - - ) + """, + ) + .withColumn( + "event_ts", + "timestamp", + data_range=dg.DateRange("2020-07-01 00:00:00", "2020-07-31 11:59:59", "seconds=1"), + random=True, + ) + ) if dummyValues > 0: - events_dataspec = events_dataspec.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + events_dataspec = events_dataspec.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return events_dataspec - @DatasetProvider.allowed_options(options=["random", "numPlans", "numCustomers", "dummyValues", "numDays", - "averageEventsPerCustomer"]) - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: + @DatasetProvider.allowed_options( + options=["random", "numPlans", "numCustomers", "dummyValues", "numDays", "averageEventsPerCustomer"] + ) + def getTableGenerator( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: generateRandom = options.get("random", False) numPlans = options.get("numPlans", self.DEFAULT_NUM_PLANS) numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS) @@ -227,59 +328,107 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N averageEventsPerCustomer = options.get("averageEventsPerCustomer", self.DEFAULT_AVG_EVENTS_PER_CUSTOMER) if tableName == "plans": - return self.getPlans(sparkSession , rows=rows, partitions=partitions, numPlans=numPlans, - generateRandom=generateRandom, dummyValues=dummyValues) + return self.getPlans( + sparkSession, + rows=rows, + partitions=partitions, + numPlans=numPlans, + generateRandom=generateRandom, + dummyValues=dummyValues, + ) elif 
tableName == "customers": - return self.getCustomers(sparkSession, rows=rows, partitions=partitions, numCustomers=numCustomers, - generateRandom=generateRandom, numPlans=numPlans, dummyValues=dummyValues) + return self.getCustomers( + sparkSession, + rows=rows, + partitions=partitions, + numCustomers=numCustomers, + generateRandom=generateRandom, + numPlans=numPlans, + dummyValues=dummyValues, + ) elif tableName == "deviceEvents": - return self.getDeviceEvents(sparkSession, rows=rows, partitions=partitions, generateRandom=generateRandom, - numCustomers=numCustomers, numDays=numDays, dummyValues=dummyValues, - averageEventsPerCustomer=averageEventsPerCustomer) + return self.getDeviceEvents( + sparkSession, + rows=rows, + partitions=partitions, + generateRandom=generateRandom, + numCustomers=numCustomers, + numDays=numDays, + dummyValues=dummyValues, + averageEventsPerCustomer=averageEventsPerCustomer, + ) @DatasetProvider.allowed_options(options=["plans", "customers", "deviceEvents"]) - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, - **options: dict[str, Any]) -> DataGenerator: + def getAssociatedDataset( + self, + sparkSession: SparkSession, + *, + tableName: str | None = None, + rows: int = -1, + partitions: int = -1, + **options: dict[str, Any], + ) -> DataGenerator: dfPlans = options.get("plans") - assert dfPlans is not None and issubclass(type(dfPlans), DataFrame), "Option `plans` should be a dataframe of plan records" + assert dfPlans is not None and issubclass( + type(dfPlans), DataFrame + ), "Option `plans` should be a dataframe of plan records" dfCustomers = options.get("customers") - assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \ - "Option `customers` should be dataframe of customer records" + assert dfCustomers is not None and issubclass( + type(dfCustomers), DataFrame + ), "Option `customers` should be dataframe of customer records" dfDeviceEvents = options.get("deviceEvents") - assert dfDeviceEvents is not None and issubclass(type(dfDeviceEvents), DataFrame), \ - "Option `device_events` should be dataframe of device_event records" + assert dfDeviceEvents is not None and issubclass( + type(dfDeviceEvents), DataFrame + ), "Option `device_events` should be dataframe of device_event records" if tableName == "invoices": df_customer_pricing = dfCustomers.join(dfPlans, dfPlans.plan_id == dfCustomers.plan) # let's compute the summary minutes messages and bytes transferred - df_enriched_events = (dfDeviceEvents - .withColumn("message_count", - F.expr("""case + df_enriched_events = ( + dfDeviceEvents.withColumn( + "message_count", + F.expr( + """case when event_type='sms' then 1 - else 0 end""")) - .withColumn("ld_minutes", - F.expr("""case + else 0 end""" + ), + ) + .withColumn( + "ld_minutes", + F.expr( + """case when event_type='ld call' then cast(ceil(minutes) as decimal(18,3)) - else 0.0 end""")) - .withColumn("local_minutes", - F.expr("""case when event_type='local call' + else 0.0 end""" + ), + ) + .withColumn( + "local_minutes", + F.expr( + """case when event_type='local call' then cast(ceil(minutes) as decimal(18,3)) - else 0.0 end""")) - .withColumn("intl_minutes", - F.expr("""case when event_type='intl call' + else 0.0 end""" + ), + ) + .withColumn( + "intl_minutes", + F.expr( + """case when event_type='intl call' then cast(ceil(minutes) as decimal(18,3)) - else 0.0 end""")) - ) + else 0.0 end""" + ), + ) + ) 
df_enriched_events.createOrReplaceTempView("mtp_telephony_events") # compute summary activity - df_summary = sparkSession.sql("""select device_id, + df_summary = sparkSession.sql( + """select device_id, round(sum(bytes_transferred) / 1000000.0, 3) as total_mb, sum(message_count) as total_messages, sum(ld_minutes) as total_ld_minutes, @@ -289,12 +438,16 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non from mtp_telephony_events group by device_id - """) + """ + ) df_summary.createOrReplaceTempView("mtp_event_summary") - df_customer_pricing.join(df_summary,df_customer_pricing.device_id == df_summary.device_id).createOrReplaceTempView("mtp_customer_summary") + df_customer_pricing.join( + df_summary, df_customer_pricing.device_id == df_summary.device_id + ).createOrReplaceTempView("mtp_customer_summary") - df_invoices = sparkSession.sql(""" + df_invoices = sparkSession.sql( + """ select *, internet_cost + sms_cost + ld_cost + local_cost + intl_cost as total_invoice @@ -317,6 +470,7 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non as sms_cost from mtp_customer_summary) - """) + """ + ) return df_invoices diff --git a/dbldatagen/datasets_object.py b/dbldatagen/datasets_object.py index 067a37a3..22fa5673 100644 --- a/dbldatagen/datasets_object.py +++ b/dbldatagen/datasets_object.py @@ -50,10 +50,7 @@ class Datasets: @classmethod def getProviderDefinitions( - cls, - name: str | None = None, - pattern: str | None = None, - supportsStreaming: bool = False + cls, name: str | None = None, pattern: str | None = None, supportsStreaming: bool = False ) -> list[DatasetProvider.DatasetDefinition]: """ Gets provider definitions for one or more datasets. @@ -64,25 +61,31 @@ def getProviderDefinitions( :returns: List of provider definitions matching input name and pattern. 
""" if pattern is not None and name is not None: - summary_list = [provider_definition - for provider_definition in DatasetProvider.getRegisteredDatasets().values() - if re.match(pattern, provider_definition.name) and name == provider_definition.name] + summary_list = [ + provider_definition + for provider_definition in DatasetProvider.getRegisteredDatasets().values() + if re.match(pattern, provider_definition.name) and name == provider_definition.name + ] elif pattern is not None: - summary_list = [provider_definition - for provider_definition in DatasetProvider.getRegisteredDatasets().values() - if re.match(pattern, provider_definition.name)] + summary_list = [ + provider_definition + for provider_definition in DatasetProvider.getRegisteredDatasets().values() + if re.match(pattern, provider_definition.name) + ] elif name is not None: - summary_list = [provider_definition - for provider_definition in DatasetProvider.getRegisteredDatasets().values() - if name == provider_definition.name] + summary_list = [ + provider_definition + for provider_definition in DatasetProvider.getRegisteredDatasets().values() + if name == provider_definition.name + ] else: summary_list = list(DatasetProvider.getRegisteredDatasets().values()) # filter for streaming if supportsStreaming: - summary_list_filtered = [provider_definition - for provider_definition in summary_list - if provider_definition.supportsStreaming] + summary_list_filtered = [ + provider_definition for provider_definition in summary_list if provider_definition.supportsStreaming + ] return summary_list_filtered else: return summary_list @@ -95,9 +98,14 @@ def list(cls, pattern: str | None = None, supportsStreaming: bool = False) -> No :param pattern: Pattern to match dataset names. If ``None``, all datasets are listed. :param supportsStreaming: If ``true``, filters out dataset providers that don't support streaming. """ - summary_list = sorted([(providerDefinition.name, providerDefinition.summary) for - providerDefinition in cls.getProviderDefinitions(name=None, pattern=pattern, - supportsStreaming=supportsStreaming)]) + summary_list = sorted( + [ + (providerDefinition.name, providerDefinition.summary) + for providerDefinition in cls.getProviderDefinitions( + name=None, pattern=pattern, supportsStreaming=supportsStreaming + ) + ] + ) print("The followed datasets are registered and available for use:") @@ -180,9 +188,7 @@ def _getNavigator(self) -> NavigatorNode: return self._navigator def _getProviderInstanceAndMetadata( - self, - providerName: str, - supportsStreaming: bool + self, providerName: str, supportsStreaming: bool ) -> tuple[DatasetProvider, DatasetProvider.DatasetDefinition]: """ Gets a dataset provider and definition. 
@@ -201,8 +207,9 @@ def _getProviderInstanceAndMetadata( providerClass = providerDefinition.providerClass - assert providerClass is not None and DatasetProvider.isValidDataProviderType(providerClass), \ - f"Dataset provider incorrectly configured for name {self._name}" + assert providerClass is not None and DatasetProvider.isValidDataProviderType( + providerClass + ), f"Dataset provider incorrectly configured for name {self._name}" providerInstance = providerClass() @@ -222,8 +229,9 @@ def _get( :returns: `DataGenerator` for the requested table """ - providerInstance, providerDefinition = \ - self._getProviderInstanceAndMetadata(providerName, supportsStreaming=self._streamingRequired) + providerInstance, providerDefinition = self._getProviderInstanceAndMetadata( + providerName, supportsStreaming=self._streamingRequired + ) if tableName is None: tableName = providerDefinition.primaryTable @@ -233,11 +241,7 @@ def _get( raise ValueError(f"Table `{tableName}` not a recognized table option") return providerInstance.getTableGenerator( - self._sparkSession, - tableName=tableName, - rows=rows, - partitions=partitions, - **kwargs + self._sparkSession, tableName=tableName, rows=rows, partitions=partitions, **kwargs ) def get(self, table: str | None = None, rows: int = -1, partitions: int = -1, **kwargs) -> DataGenerator: @@ -384,9 +388,9 @@ def __init__( datasets: Datasets, providerName: str | None = None, tableName: str | None = None, - location: list[str] | None = None + location: list[str] | None = None, ) -> None: - """ Initialization for node + """Initialization for node :param datasets: instance of datasets object :param providerName: provider name for node @@ -403,11 +407,7 @@ def __repr__(self) -> str: return f"Node: (datasets: {self._datasets}, provider: {self._providerName}, loc: {self._location} )" def _addEntry( - self, - datasets: Datasets, - steps: list[str] | None, - providerName: str | None, - tableName: str | None + self, datasets: Datasets, steps: list[str] | None, providerName: str | None, tableName: str | None ) -> NavigatorNode: """ Adds an entry to the dataset navigator. diff --git a/dbldatagen/daterange.py b/dbldatagen/daterange.py index 9dad8298..270261b4 100644 --- a/dbldatagen/daterange.py +++ b/dbldatagen/daterange.py @@ -1,49 +1,63 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +""" +This module defines the DataRange abstract class. +""" + +import datetime import math -from datetime import datetime, timedelta -from .datarange import DataRange -from .utils import parse_time_interval -from .serialization import SerializableToDict +from pyspark.sql.types import DataType +from dbldatagen.datarange import DataRange +from dbldatagen.serialization import SerializableToDict +from dbldatagen.utils import parse_time_interval -class DateRange(DataRange): - """Class to represent Date range - The date range will represented internally using `datetime` for `start` and `end`, and `timedelta` for `interval` +class DateRange(DataRange): + """Represents a date range. The date range will be represented internally using `datetime` for + `begin` and `end`, and `timedelta` for `interval`. When computing ranges for purposes of the sequences, the maximum value will be adjusted to the nearest whole multiple of the interval that is before the `end` value. 
When converting from a string, datetime is assumed to use local timezone unless specified as part of the format - in keeping with the python `datetime` handling of datetime instances that do not specify a timezone - - :param begin: start of date range as python `datetime` object. If specified as string, converted to datetime - :param end: end of date range as python `datetime` object. If specified as string, converted to datetime - :param interval: interval of date range as python `timedelta` object. - Note parsing format for interval uses standard timedelta parsing not - the `datetime_format` string - :param datetime_format: format for conversion of strings to datetime objects + in keeping with the python `datetime` handling of datetime instances that do not specify a timezone. + :param begin: Start of the date range as a `datetime.datetime` or `str` + :param end: End of the date range as a `datetime.datetime` or `str` + :param interval: Date range increment as a `datetime.timedelta` or `str`; Parsing format for interval uses + standard timedelta parsing not the `datetime_format` string. If not specified, defaults to 1 day. + :param datetime_format: String format for converting strings to `datetime.datetime` """ DEFAULT_UTC_TS_FORMAT = "%Y-%m-%d %H:%M:%S" DEFAULT_DATE_FORMAT = "%Y-%m-%d" - - DEFAULT_START_TIMESTAMP = datetime(year=datetime.now().year - 1, month=1, day=1) - DEFAULT_END_TIMESTAMP = datetime(year=datetime.now().year - 1, month=12, day=31, hour=23, minute=59, second=59) + DEFAULT_START_TIMESTAMP = datetime.datetime(year=datetime.datetime.now().year - 1, month=1, day=1) + DEFAULT_END_TIMESTAMP = datetime.datetime( + year=datetime.datetime.now().year - 1, month=12, day=31, hour=23, minute=59, second=59 + ) DEFAULT_START_DATE = DEFAULT_START_TIMESTAMP.date() DEFAULT_END_DATE = DEFAULT_END_TIMESTAMP.date() DEFAULT_START_DATE_TIMESTAMP = DEFAULT_START_TIMESTAMP.replace(hour=0, minute=0, second=0) DEFAULT_END_DATE_TIMESTAMP = DEFAULT_END_TIMESTAMP.replace(hour=0, minute=0, second=0) - # todo: deduce format from begin and end params - - def __init__(self, begin, end, interval=None, datetime_format=DEFAULT_UTC_TS_FORMAT): - assert begin is not None, "`begin` must be specified" - assert end is not None, "`end` must be specified" + begin: datetime.datetime + end: datetime.datetime + interval: datetime.timedelta + minValue: float + maxValue: float + step: float + + def __init__( + self, + begin: str | datetime.datetime, + end: str | datetime.datetime, + interval: str | datetime.timedelta, + datetime_format: str = DEFAULT_UTC_TS_FORMAT, + ) -> None: self.datetime_format = datetime_format self.begin = begin if not isinstance(begin, str) else self._datetime_from_string(begin, datetime_format) @@ -51,47 +65,67 @@ def __init__(self, begin, end, interval=None, datetime_format=DEFAULT_UTC_TS_FOR self.interval = interval if not isinstance(interval, str) else self._timedelta_from_string(interval) self.minValue = self.begin.timestamp() - - self.maxValue = (self.minValue + self.interval.total_seconds() - * self.computeTimestampIntervals(self.begin, self.end, self.interval)) + self.maxValue = self.minValue + self.interval.total_seconds() * self.computeTimestampIntervals( + self.begin, self.end, self.interval + ) self.step = self.interval.total_seconds() - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. 
- :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's constructor arguments. + + :return: Dictionary representation of the object """ _options = { "kind": self.__class__.__name__, - "begin": datetime.strftime(self.begin, self.datetime_format), - "end": datetime.strftime(self.end, self.datetime_format), + "begin": datetime.datetime.strftime(self.begin, self.datetime_format), + "end": datetime.datetime.strftime(self.end, self.datetime_format), "interval": f"INTERVAL {int(self.interval.total_seconds())} SECONDS", - "datetime_format": self.datetime_format + "datetime_format": self.datetime_format, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } @classmethod - def _datetime_from_string(cls, date_str, date_format): - """convert string to Python DateTime object using format""" - result = datetime.strptime(date_str, date_format) + def _datetime_from_string(cls, date_str: str, date_format: str) -> datetime.datetime: + """Converts a string to a `datetime.datetime` object using the specified format. + + :param date_str: String to convert to `datetime.datetime` + :param date_format: Format string for conversion + :return: `datetime.datetime` object + """ + result = datetime.datetime.strptime(date_str, date_format) return result @classmethod - def _timedelta_from_string(cls, interval): + def _timedelta_from_string(cls, interval: str | None) -> datetime.timedelta: + """Converts a string to a `datetime.timedelta` object using the specified format. + + :param interval: String to convert to `datetime.timedelta` + :return: `datetime.timedelta` object + """ return cls.parseInterval(interval) @classmethod - def parseInterval(cls, interval_str): + def parseInterval(cls, interval_str: str | None) -> datetime.timedelta: """Parse interval from string""" - assert interval_str is not None, "`interval_str` must be specified" + if interval_str is None: + raise ValueError("Parameter 'interval_str' must be specified") return parse_time_interval(interval_str) @classmethod - def _getDateTime(cls, dt, datetime_format, default_value): + def _getDateTime( + cls, dt: str | datetime.datetime | None, datetime_format: str, default_value: datetime.datetime + ) -> datetime.datetime: + """Gets a `datetime.datetime` object from the specified string, datetime object, or default value. + + :param dt: String to convert to `datetime.datetime` + :param datetime_format: Format string for conversion + :param default_value: Default value to return if `dt` is None + :return: `datetime.datetime` object + """ if isinstance(dt, str): effective_dt = cls._datetime_from_string(dt, datetime_format) elif dt is None: @@ -101,7 +135,15 @@ def _getDateTime(cls, dt, datetime_format, default_value): return effective_dt @classmethod - def _getInterval(cls, interval, default_value): + def _getInterval( + cls, interval: str | datetime.timedelta | None, default_value: datetime.timedelta + ) -> datetime.timedelta: + """Gets a `datetime.timedelta` object from the specified string, timedelta object, or default value. 
+ + :param interval: String to convert to `datetime.timedelta` + :param default_value: Default value to return if `interval` is None + :return: `datetime.timedelta` object + """ if isinstance(interval, str): effective_interval = parse_time_interval(interval) elif interval is None: @@ -111,15 +153,28 @@ def _getInterval(cls, interval, default_value): return effective_interval @classmethod - def computeDateRange(cls, begin, end, interval, unique_values): - effective_interval = cls._getInterval(interval, timedelta(days=1)) - effective_end = cls._getDateTime(end, DateRange.DEFAULT_DATE_FORMAT, - cls.DEFAULT_END_DATE_TIMESTAMP) + def computeDateRange( + cls, + begin: datetime.datetime | None, + end: datetime.datetime | None, + interval: str | datetime.timedelta | None, + unique_values: int | None, + ) -> "DateRange": + """Computes a date range from the specified begin, end, interval, and unique values. + + :param begin: Start of the date range as a `datetime.datetime` or `str` + :param end: End of the date range as a `datetime.datetime` or `str` + :param interval: Date range increment as a `datetime.timedelta` or `str` + :param unique_values: Number of unique values to generate + :return: `DateRange` object + """ + effective_interval = cls._getInterval(interval, datetime.timedelta(days=1)) + effective_end = cls._getDateTime(end, DateRange.DEFAULT_DATE_FORMAT, cls.DEFAULT_END_DATE_TIMESTAMP) effective_begin = cls._getDateTime(begin, DateRange.DEFAULT_DATE_FORMAT, cls.DEFAULT_START_DATE_TIMESTAMP) if unique_values is not None: - assert type(unique_values) is int, "unique_values must be integer" - assert unique_values >= 1, "unique_values must be positive integer" + if unique_values < 1: + raise ValueError("Parameter 'unique_values' must be a positive integer") effective_begin = effective_end - effective_interval * (unique_values - 1) @@ -127,62 +182,92 @@ def computeDateRange(cls, begin, end, interval, unique_values): return result @classmethod - def computeTimestampRange(cls, begin, end, interval, unique_values): + def computeTimestampRange( + cls, + begin: datetime.datetime | None, + end: datetime.datetime | None, + interval: str | datetime.timedelta | None, + unique_values: int | None, + ) -> "DateRange": + """Computes a timestamp range from the specified begin, end, interval, and unique values. 
+ + :param begin: Start of the timestamp range as a `datetime.datetime` or `str` + :param end: End of the timestamp range as a `datetime.datetime` or `str` + :param interval: Timestamp range increment as a `datetime.timedelta` or `str` + :param unique_values: Number of unique values to generate + :return: `DateRange` object + """ - effective_interval = cls._getInterval(interval, timedelta(days=1)) + effective_interval = cls._getInterval(interval, datetime.timedelta(days=1)) effective_end = cls._getDateTime(end, DateRange.DEFAULT_UTC_TS_FORMAT, cls.DEFAULT_END_TIMESTAMP) effective_begin = cls._getDateTime(begin, DateRange.DEFAULT_UTC_TS_FORMAT, cls.DEFAULT_START_TIMESTAMP) if unique_values is not None: - assert type(unique_values) is int, "unique_values must be integer" - assert unique_values >= 1, "unique_values must be positive integer" + if unique_values < 1: + raise ValueError("Parameter 'unique_values' must be a positive integer") effective_begin = effective_end - effective_interval * (unique_values - 1) result = DateRange(effective_begin, effective_end, effective_interval) return result - def __str__(self): - """ create string representation of date range""" + def __str__(self) -> str: + """Creates a string representation of the date range. + + :return: String representation of the date range + """ return f"DateRange({self.begin},{self.end},{self.interval} == {self.minValue}, {self.maxValue}, {self.step})" - def computeTimestampIntervals(self, start, end, interval): - """ Compute number of intervals between start and end date """ - assert type(start) is datetime, "Expecting start as type datetime.datetime" - assert type(end) is datetime, "Expecting end as type datetime.datetime" - assert type(interval) is timedelta, "Expecting interval as type datetime.timedelta" + def computeTimestampIntervals( + self, start: datetime.datetime, end: datetime.datetime, interval: datetime.timedelta + ) -> int: + """Computes the number of intervals between the specified start and end dates. + + :param start: Start of the timestamp range as a `datetime.datetime` + :param end: End of the timestamp range as a `datetime.datetime` + :param interval: Timestamp range increment as a `datetime.timedelta` + :return: Number of intervals between the start and end dates + """ i1 = end - start ni1 = i1 / interval return math.floor(ni1) - def isFullyPopulated(self): - """Check if minValue, maxValue and step are specified """ + def isFullyPopulated(self) -> bool: + """Checks if the date range is fully populated. + + :return: True if the date range is fully populated, False otherwise + """ return self.minValue is not None and self.maxValue is not None and self.step is not None - def adjustForColumnDatatype(self, ctype): - """ adjust the range for the column output type + def adjustForColumnDatatype(self, ctype: DataType) -> None: + """Adjusts the date range for the specified column output type. :param ctype: Spark SQL data type for column - """ pass - def getDiscreteRange(self): - """ Divide continuous range into discrete intervals + def getDiscreteRange(self) -> float: + """Gets the discrete range of the date range. - Note does not modify range object. + :return: Discrete range object + """ + return (self.maxValue - self.minValue) * float(1.0 / self.step) - :returns: range from minValue to maxValue + def isEmpty(self) -> bool: + """Checks if the date range is empty. 
+ :return: True if the date range is empty, False otherwise """ - return (self.maxValue - self.minValue) * float(1.0 / self.step) + return False - def isEmpty(self): - """Check if object is empty (i.e all instance vars of note are `None`)""" - return self.begin is None and self.end is None and self.interval is None + def getContinuousRange(self) -> float: + """Gets the continuous range of the date range. - def getContinuousRange(self): - """Convert range to continuous range""" - return (self.maxValue - self.minValue) * float(1.0) + :return: Continuous range as a float + """ + return (self.maxValue - self.minValue) * 1.0 + + def getScale(self) -> int: + """Gets the scale of the date range. - def getScale(self): - """Get scale of range""" + :return: Scale of the date range + """ return 0 diff --git a/dbldatagen/distributions/__init__.py b/dbldatagen/distributions/__init__.py index ca1dcece..49095346 100644 --- a/dbldatagen/distributions/__init__.py +++ b/dbldatagen/distributions/__init__.py @@ -26,8 +26,4 @@ from .exponential_distribution import Exponential -__all__ = ["normal_distribution", - "gamma", - "beta", - "data_distribution", - "exponential_distribution"] +__all__ = ["normal_distribution", "gamma", "beta", "data_distribution", "exponential_distribution"] diff --git a/dbldatagen/distributions/beta.py b/dbldatagen/distributions/beta.py index ba2f4858..6faa0898 100644 --- a/dbldatagen/distributions/beta.py +++ b/dbldatagen/distributions/beta.py @@ -7,71 +7,75 @@ """ +import pandas as pd import pyspark.sql.functions as F +from pyspark.sql import Column from pyspark.sql.types import FloatType -import numpy as np -import pandas as pd - -from .data_distribution import DataDistribution -from ..serialization import SerializableToDict +from dbldatagen.datagen_types import NumericLike +from dbldatagen.distributions.data_distribution import DataDistribution +from dbldatagen.serialization import SerializableToDict class Beta(DataDistribution): - """ Specify that random samples should be drawn from the Beta distribution parameterized by alpha and beta - - :param alpha: value for alpha parameter - float, int or other numeric value, greater than 0 - :param beta: value for beta parameter - float, int or other numeric value, greater than 0 + """Specify that random samples should be drawn from the Beta distribution parameterized by alpha and beta. By + default the Beta distribution produces values between 0 and 1 so no scaling is needed. See + https://en.wikipedia.org/wiki/Beta_distribution. - See https://en.wikipedia.org/wiki/Beta_distribution - - By default the Beta distribution produces values between 0 and 1 so no scaling is needed + :param alpha: Alpha parameter value; Should be a float, int or other numeric value, greater than 0 + :param beta: Beta parameter value; Should be a float, int or other numeric value, greater than 0 """ - def __init__(self, alpha=None, beta=None): + def __init__(self, alpha: NumericLike | None = None, beta: NumericLike | None = None) -> None: DataDistribution.__init__(self) - - assert type(alpha) in [float, int, np.float64, np.int32, np.int64], "alpha must be int-like or float-like" - assert type(beta) in [float, int, np.float64, np.int32, np.int64], "beta must be int-like or float-like" self._alpha = alpha self._beta = beta - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. 
- :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Python dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "alpha": self._alpha, "beta": self._beta} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } @property - def alpha(self): - """ Return alpha parameter.""" + def alpha(self) -> NumericLike | None: + """Returns the alpha parameter value. + + :return: Alpha parameter value + """ return self._alpha @property - def beta(self): - """ Return beta parameter.""" + def beta(self) -> NumericLike | None: + """Returns the beta parameter value. + + :return: Beta parameter value + """ return self._beta - def __str__(self): - """ Return string representation of object""" + def __str__(self) -> str: + """Returns a string representation of the object. + + :return: String representation of the object + """ return f"BetaDistribution(alpha={self._alpha}, beta={self._beta}, randomSeed={self.randomSeed})" @staticmethod def beta_func(alpha_series: pd.Series, beta_series: pd.Series, random_seed: pd.Series) -> pd.Series: - """ Generate sample of beta distribution using pandas / numpy + """Generates samples from the beta distribution using pandas / numpy. :param alpha_series: value for alpha parameter as Pandas Series :param beta_series: value for beta parameter as Pandas Series :param random_seed: value for randomSeed parameter as Pandas Series :return: random samples from distribution scaled to values between 0 and 1 - """ alpha = alpha_series.to_numpy() beta = beta_series.to_numpy() @@ -82,14 +86,14 @@ def beta_func(alpha_series: pd.Series, beta_series: pd.Series, random_seed: pd.S results = rng.beta(alpha, beta) return pd.Series(results) - def generateNormalizedDistributionSample(self): - """ Generate sample of data for distribution + def generateNormalizedDistributionSample(self) -> Column: + """Generates a sample of data for the distribution. 
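A rough usage sketch for the Beta distribution above, assuming `Beta` is importable from `dbldatagen.distributions` and that the column spec's `distribution` option accepts these distribution objects:

    import dbldatagen as dg
    from dbldatagen.distributions import Beta
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = (dg.DataGenerator(spark, rows=1000)
          .withColumn("score", "float", minValue=0.0, maxValue=1.0,
                      distribution=Beta(alpha=2.0, beta=5.0), random=True)
          .build())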
- :return: random samples from distribution scaled to values between 0 and 1 + :return: Pyspark SQL column expression for the sample values """ - beta_sample = F.pandas_udf(self.beta_func, returnType=FloatType()).asNondeterministic() + beta_sample = F.pandas_udf(self.beta_func, returnType=FloatType()).asNondeterministic() # type: ignore - newDef = beta_sample(F.lit(self._alpha), - F.lit(self._beta), - F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1)) + newDef: Column = beta_sample( + F.lit(self._alpha), F.lit(self._beta), F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1) + ) return newDef diff --git a/dbldatagen/distributions/data_distribution.py b/dbldatagen/distributions/data_distribution.py index d8abe6cc..5e7839b8 100644 --- a/dbldatagen/distributions/data_distribution.py +++ b/dbldatagen/distributions/data_distribution.py @@ -23,73 +23,72 @@ from abc import ABC, abstractmethod import numpy as np -import pyspark.sql.functions as F +from pyspark.sql import Column -from ..serialization import SerializableToDict +from dbldatagen.serialization import SerializableToDict class DataDistribution(SerializableToDict, ABC): - """ Base class for all distributions""" + """Base class for all distributions""" - def __init__(self): - self._rounding = False - self._randomSeed = None + _randomSeed: int | np.int32 | np.int64 | None = None + _rounding: bool = False @staticmethod - def get_np_random_generator(random_seed): - """ Get numpy random number generator + def get_np_random_generator(random_seed: int | np.int32 | np.int64 | None) -> np.random.Generator: + """Gets a numpy random number generator. - :param random_seed: Numeric random seed to use. If < 0, then no random - :return: + :param random_seed: Numeric random seed to use; If < 0, then no random + :return: Numpy random number generator """ - assert random_seed is None or type(random_seed) in [np.int32, np.int64, int], \ - f"`randomSeed` must be int or int-like not {type(random_seed)}" - from numpy.random import default_rng if random_seed not in (-1, -1.0): - rng = default_rng(random_seed) + rng = np.random.default_rng(random_seed) else: - rng = default_rng() - + rng = np.random.default_rng() return rng @abstractmethod - def generateNormalizedDistributionSample(self): - """ Generate sample of data for distribution - - :return: random samples from distribution scaled to values between 0 and 1 + def generateNormalizedDistributionSample(self) -> Column: + """Generates a sample of data for the distribution. Implementors must provide an implementation for this method. - Note implementors should provide implementation for this, - - Return value is expected to be a Pyspark SQL column expression such as F.expr("rand()") + :return: Pyspark SQL column expression for the sample """ - pass + raise NotImplementedError( + f"Class '{self.__class__.__name__}' does not implement 'generateNormalizedDistributionSample'" + ) - def withRounding(self, rounding): - """ Create copy of object and set the rounding attribute + def withRounding(self, rounding: bool) -> "DataDistribution": + """Creates a copy of the object and sets the rounding attribute. - :param rounding: rounding value to set. 
Should be True or False - :return: new instance of data distribution object with rounding set + :param rounding: Rounding value to set + :return: New instance of data distribution object with rounding set """ new_distribution_instance = copy.copy(self) new_distribution_instance._rounding = rounding return new_distribution_instance @property - def rounding(self): - """get the `rounding` attribute """ + def rounding(self) -> bool: + """Returns the rounding attribute. + + :return: Rounding attribute + """ return self._rounding - def withRandomSeed(self, seed): - """ Create copy of object and set the random seed attribute + def withRandomSeed(self, seed: int | np.int32 | np.int64 | None) -> "DataDistribution": + """Creates a copy of the object and with a new random seed value. - :param seed: random generator seed value to set. Should be integer, float or None - :return: new instance of data distribution object with rounding set + :param seed: Random generator seed value to set; Should be integer, float or None + :return: New instance of data distribution object with random seed set """ new_distribution_instance = copy.copy(self) new_distribution_instance._randomSeed = seed return new_distribution_instance @property - def randomSeed(self): - """get the `randomSeed` attribute """ + def randomSeed(self) -> int | np.int32 | np.int64 | None: + """Returns the random seed attribute. + + :return: Random seed attribute + """ return self._randomSeed diff --git a/dbldatagen/distributions/exponential_distribution.py b/dbldatagen/distributions/exponential_distribution.py index af9751d2..c924b447 100644 --- a/dbldatagen/distributions/exponential_distribution.py +++ b/dbldatagen/distributions/exponential_distribution.py @@ -9,52 +9,53 @@ import numpy as np import pandas as pd - import pyspark.sql.functions as F +from pyspark.sql import Column from pyspark.sql.types import FloatType -from .data_distribution import DataDistribution -from ..serialization import SerializableToDict +from dbldatagen.datagen_types import NumericLike +from dbldatagen.distributions.data_distribution import DataDistribution +from dbldatagen.serialization import SerializableToDict class Exponential(DataDistribution): - """ Specify that random samples should be drawn from the exponential distribution parameterized by rate - - :param rate: value for rate parameter - float, int or other numeric value, greater than 0 - - See https://en.wikipedia.org/wiki/Exponential_distribution + """Specifies that random samples should be drawn from the exponential distribution parameterized + by rate. See https://en.wikipedia.org/wiki/Exponential_distribution - Scaling is performed to normalize values between 0 and 1 + :param rate: Value for rate parameter; Should be a float, int or other numeric value greater than 0 """ - def __init__(self, rate=None): + def __init__(self, rate: NumericLike | None = None) -> None: DataDistribution.__init__(self) self._rate = rate - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. 
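Because `withRandomSeed` and `withRounding` above return modified copies rather than mutating the receiver, distributions can be configured fluently; a minimal sketch, assuming `Gamma` is importable from `dbldatagen.distributions`:

    from dbldatagen.distributions import Gamma

    base = Gamma(shape=1.0, scale=2.0)
    seeded = base.withRandomSeed(42).withRounding(True)

    print(base.randomSeed, seeded.randomSeed)   # None 42
    print(base.rounding, seeded.rounding)       # False True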
+ + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "rate": self._rate} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def __str__(self): - """ Return string representation""" + def __str__(self) -> str: + """Returns a string representation of the object. + + :return: String representation of the object + """ return f"ExponentialDistribution(rate={self.rate}, randomSeed={self.randomSeed})" @staticmethod def exponential_func(scale_series: pd.Series, random_seed: pd.Series) -> pd.Series: - """ Generate sample of exponential distribution using pandas / numpy - - :param scale_series: value for scale parameter as Pandas Series - :param random_seed: value for randomSeed parameter as Pandas Series - :return: random samples from distribution scaled to values between 0 and 1 + """Generates samples from the exponential distribution using pandas / numpy. + :param scale_series: Value for scale parameter as Pandas Series + :param random_seed: Value for randomSeed parameter as Pandas Series + :return: Random samples from distribution scaled to values between 0 and 1 """ scale_param = scale_series.to_numpy() random_seed = random_seed.to_numpy()[0] @@ -76,23 +77,37 @@ def exponential_func(scale_series: pd.Series, random_seed: pd.Series) -> pd.Seri return pd.Series(results2) @property - def rate(self): - """ Return rate parameter""" + def rate(self) -> NumericLike | None: + """Returns the rate parameter. + + :return: Rate parameter + """ return self._rate @property - def scale(self): - """ Return scale implicit parameter. Scale is 1/rate""" + def scale(self) -> float | np.floating | None: + """Returns the scale implicit parameter. Scale is 1/rate. + + :return: Scale implicit parameter + """ + if not self._rate: + raise ValueError("Cannot compute value for 'scale'; Missing value for 'rate'") return 1.0 / self._rate - def generateNormalizedDistributionSample(self): - """ Generate sample of data for distribution + def generateNormalizedDistributionSample(self) -> Column: + """Generates a sample of data for the distribution. 
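As a quick arithmetic check of the `scale` property above (the `Exponential` import is confirmed by the distributions `__init__.py` hunk earlier in this patch): scale is simply the reciprocal of the configured rate, and a missing rate now raises instead of dividing by `None`:

    from dbldatagen.distributions import Exponential

    e = Exponential(rate=0.5)
    print(e.scale)        # 2.0, i.e. 1 / rate

    Exponential().scale   # raises ValueError because no rate was supplied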
- :return: random samples from distribution scaled to values between 0 and 1 + :return: Pyspark SQL column expression for the sample values """ - exponential_sample = F.pandas_udf(self.exponential_func, returnType=FloatType()).asNondeterministic() + if not self._rate: + raise ValueError("Cannot compute value for 'scale'; Missing value for 'rate'") + + exponential_sample = F.pandas_udf( # type: ignore + self.exponential_func, returnType=FloatType() + ).asNondeterministic() # scala formulation uses scale = 1/rate - newDef = exponential_sample(F.lit(1.0 / self._rate), - F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0)) + newDef: Column = exponential_sample( + F.lit(1.0 / self._rate), F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0) + ) return newDef diff --git a/dbldatagen/distributions/gamma.py b/dbldatagen/distributions/gamma.py index 9fa67a23..6af3074c 100644 --- a/dbldatagen/distributions/gamma.py +++ b/dbldatagen/distributions/gamma.py @@ -10,66 +10,72 @@ import numpy as np import pandas as pd import pyspark.sql.functions as F +from pyspark.sql import Column from pyspark.sql.types import FloatType -from .data_distribution import DataDistribution -from ..serialization import SerializableToDict +from dbldatagen.datagen_types import NumericLike +from dbldatagen.distributions.data_distribution import DataDistribution +from dbldatagen.serialization import SerializableToDict class Gamma(DataDistribution): - """ Specify Gamma distribution with specific shape and scale - - :param shape: shape parameter (k) - :param scale: scale parameter (theta) - - See https://en.wikipedia.org/wiki/Gamma_distribution - - Scaling is performed to normalize values between 0 and 1 + """Specifies that random samples should be drawn from the gamma distribution parameterized by shape + and scale. See https://en.wikipedia.org/wiki/Gamma_distribution. + :param shape: Shape parameter; Should be a float, int or other numeric value greater than 0 + :param scale: Scale parameter; Should be a float, int or other numeric value greater than 0 """ - def __init__(self, shape, scale): + def __init__(self, shape: NumericLike | None = None, scale: NumericLike | None = None) -> None: DataDistribution.__init__(self) - assert type(shape) in [float, int, np.float64, np.int32, np.int64], "alpha must be int-like or float-like" - assert type(scale) in [float, int, np.float64, np.int32, np.int64], "beta must be int-like or float-like" self._shape = shape self._scale = scale - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "shape": self._shape, "scale": self._scale} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } @property - def shape(self): - """ Return shape parameter.""" + def shape(self) -> NumericLike | None: + """Returns the shape parameter. 
+ + :return: Shape parameter + """ return self._shape @property - def scale(self): - """ Return scale parameter.""" + def scale(self) -> NumericLike | None: + """Returns the scale parameter. + + :return: Scale parameter + """ return self._scale - def __str__(self): - """ Return string representation of object """ + def __str__(self) -> str: + """Returns a string representation of the object. + + :return: String representation of the object + """ return f"GammaDistribution(shape(`k`)={self._shape}, scale(`theta`)={self._scale}, seed={self.randomSeed})" @staticmethod def gamma_func(shape_series: pd.Series, scale_series: pd.Series, random_seed: pd.Series) -> pd.Series: - """ Pandas / Numpy based function to generate gamma samples + """Generates samples from the gamma distribution using pandas / numpy. - :param shape_series: pandas series of shape (k) values - :param scale_series: pandas series of scale (theta) values - :param random_seed: pandas series of random seed values + :param shape_series: Value for shape parameter as Pandas Series + :param scale_series: Value for scale parameter as Pandas Series + :param random_seed: Value for randomSeed parameter as Pandas Series - :return: Samples scaled from 0 .. 1 + :return: Random samples from distribution scaled to values between 0 and 1 """ shape = shape_series.to_numpy() scale = scale_series.to_numpy() @@ -90,14 +96,16 @@ def gamma_func(shape_series: pd.Series, scale_series: pd.Series, random_seed: pd results2 = adjusted_results / scaling_factor return pd.Series(results2) - def generateNormalizedDistributionSample(self): - """ Generate sample of data for distribution + def generateNormalizedDistributionSample(self) -> Column: + """Generates a sample of data for the distribution. - :return: random samples from distribution scaled to values between 0 and 1 + :return: Pyspark SQL column expression for the sample values """ - gamma_sample = F.pandas_udf(self.gamma_func, returnType=FloatType()).asNondeterministic() + gamma_sample = F.pandas_udf(self.gamma_func, returnType=FloatType()).asNondeterministic() # type: ignore - newDef = gamma_sample(F.lit(self._shape), - F.lit(self._scale), - F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0)) + newDef: Column = gamma_sample( + F.lit(self._shape), + F.lit(self._scale), + F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0), + ) return newDef diff --git a/dbldatagen/distributions/normal_distribution.py b/dbldatagen/distributions/normal_distribution.py index d36daef5..dc59186c 100644 --- a/dbldatagen/distributions/normal_distribution.py +++ b/dbldatagen/distributions/normal_distribution.py @@ -9,46 +9,48 @@ import numpy as np import pandas as pd - import pyspark.sql.functions as F +from pyspark.sql import Column from pyspark.sql.types import FloatType -from .data_distribution import DataDistribution -from ..serialization import SerializableToDict +from dbldatagen.datagen_types import NumericLike +from dbldatagen.distributions.data_distribution import DataDistribution +from dbldatagen.serialization import SerializableToDict class Normal(DataDistribution): - def __init__(self, mean, stddev): - ''' Specify that data should follow normal distribution + def __init__(self, mean: NumericLike | None = None, stddev: NumericLike | None = None) -> None: + """Specifies that random samples should be drawn from the normal distribution parameterized by mean and standard deviation. 
- :param mean: mean of distribution - :param stddev: standard deviation of distribution - ''' + :param mean: Value for mean parameter; Should be a float, int or other numeric value + :param stddev: Value for standard deviation parameter; Should be a float, int or other numeric value + """ DataDistribution.__init__(self) self.mean = mean if mean is not None else 0.0 self.stddev = stddev if stddev is not None else 1.0 - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + + :return: Dictionary representation of the object """ _options = {"kind": self.__class__.__name__, "mean": self.mean, "stddev": self.stddev} return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } @staticmethod def normal_func(mean_series: pd.Series, std_dev_series: pd.Series, random_seed: pd.Series) -> pd.Series: - """ Pandas / Numpy based function to generate normal / gaussian samples + """Generates samples from the normal distribution using pandas / numpy. - :param mean_series: pandas series of mean values - :param std_dev_series: pandas series of standard deviation values - :param random_seed: pandas series of random seed values + :param mean_series: Value for mean parameter as Pandas Series + :param std_dev_series: Value for standard deviation parameter as Pandas Series + :param random_seed: Value for randomSeed parameter as Pandas Series - :return: Samples scaled from 0 .. 1 + :return: Random samples from distribution scaled to values between 0 and 1 """ mean = mean_series.to_numpy() @@ -70,24 +72,30 @@ def normal_func(mean_series: pd.Series, std_dev_series: pd.Series, random_seed: results2 = adjusted_results / scaling_factor return pd.Series(results2) - def generateNormalizedDistributionSample(self): - """ Generate sample of data for distribution + def generateNormalizedDistributionSample(self) -> Column: + """Generates a sample of data for the distribution. - :return: random samples from distribution scaled to values between 0 and 1 + :return: Pyspark SQL column expression for the sample values """ - normal_sample = F.pandas_udf(self.normal_func, returnType=FloatType()).asNondeterministic() + normal_sample = F.pandas_udf(self.normal_func, returnType=FloatType()).asNondeterministic() # type: ignore # scala formulation uses scale = 1/rate - newDef = normal_sample(F.lit(self.mean), - F.lit(self.stddev), - F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0)) + newDef: Column = normal_sample( + F.lit(self.mean), F.lit(self.stddev), F.lit(self.randomSeed) if self.randomSeed is not None else F.lit(-1.0) + ) return newDef - def __str__(self): - """ Return string representation of object """ + def __str__(self) -> str: + """Returns a string representation of the object. 
+ + :return: String representation of the object + """ return f"NormalDistribution( mean={self.mean}, stddev={self.stddev}, randomSeed={self.randomSeed})" @classmethod - def standardNormal(cls): - """ return instance of standard normal distribution """ + def standardNormal(cls) -> "Normal": + """Returns an instance of the standard normal distribution with mean 0.0 and standard deviation 1.0. + + :return: Instance of the standard normal distribution + """ return Normal(mean=0.0, stddev=1.0) diff --git a/dbldatagen/multi_table_builder.py b/dbldatagen/multi_table_builder.py new file mode 100644 index 00000000..fb7f693e --- /dev/null +++ b/dbldatagen/multi_table_builder.py @@ -0,0 +1,276 @@ +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module defines the ``MultiTableBuilder`` class used for managing relational datasets. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + +from pyspark.sql import DataFrame + +from dbldatagen.data_generator import DataGenerator +from dbldatagen.datagen_types import ColumnLike +from dbldatagen.relation import ForeignKeyRelation +from dbldatagen.utils import DataGenError, ensure_column + + +class MultiTableBuilder: + """ + Basic builder for managing multiple related datasets backed by ``DataGenerator`` instances. + + This initial implementation focuses on tracking datasets, static DataFrames, and foreign key relations. + Related tables must share a single ``DataGenerator`` so the rows can be generated in a single pass. + """ + + def __init__(self) -> None: + self._datasets: dict[str, _DatasetDefinition] = {} + self._data_generators: list[DataGenerator] = [] + self._static_dataframes: list[DataFrame] = [] + self._foreign_key_relations: list[ForeignKeyRelation] = [] + self._generator_cache: dict[int, DataFrame] = {} + + @property + def data_generators(self) -> list[DataGenerator]: + """ + List of unique ``DataGenerator`` instances tracked by the builder. + """ + return list(dict.fromkeys(self._data_generators)) + + @property + def static_dataframes(self) -> list[DataFrame]: + """ + List of static ``DataFrame`` objects tracked by the builder. + """ + return list(self._static_dataframes) + + @property + def foreign_key_relations(self) -> list[ForeignKeyRelation]: + """ + List of registered :class:`ForeignKeyRelation` objects. + """ + return list(self._foreign_key_relations) + + def add_data_generator( + self, + name: str, + generator: DataGenerator, + columns: Sequence[ColumnLike] | None = None, + ) -> None: + """ + Register a dataset backed by a ``DataGenerator``. + + :param name: Dataset name + :param generator: Generator instance capable of producing all required columns + :param columns: Default column projection for the dataset + """ + if name in self._datasets: + raise DataGenError(f"Dataset '{name}' is already defined.") + + self._datasets[name] = _DatasetDefinition( + name=name, + generator=generator, + columns=tuple(columns) if columns is not None else None, + ) + self._data_generators.append(generator) + + def add_static_dataframe( + self, + name: str, + dataframe: DataFrame, + columns: Sequence[ColumnLike] | None = None, + ) -> None: + """ + Register a dataset backed by a pre-built ``DataFrame``. 
+ + :param name: Dataset name + :param dataframe: Static ``DataFrame`` instance + :param columns: Default column projection for the dataset + """ + if name in self._datasets: + raise DataGenError(f"Dataset '{name}' is already defined.") + + self._datasets[name] = _DatasetDefinition( + name=name, + dataframe=dataframe, + columns=tuple(columns) if columns is not None else None, + ) + self._static_dataframes.append(dataframe) + + def add_foreign_key_relation( + self, + relation: ForeignKeyRelation | None = None, + *, + from_table: str | None = None, + from_column: ColumnLike | None = None, + to_table: str | None = None, + to_column: ColumnLike | None = None, + ) -> ForeignKeyRelation: + """ + Register a foreign key relation between two datasets. + + The relation can be provided via a fully constructed ``ForeignKeyRelation`` or via keyword arguments. + + :param relation: Optional ``ForeignKeyRelation`` instance + :param from_table: Referencing dataset name (required if ``relation`` not supplied) + :param from_column: Referencing column (required if ``relation`` not supplied) + :param to_table: Referenced dataset name (required if ``relation`` not supplied) + :param to_column: Referenced column (required if ``relation`` not supplied) + :return: Registered relation + """ + if relation is None: + if not all([from_table, from_column, to_table, to_column]): + raise DataGenError("Foreign key relation requires table and column details.") + relation = ForeignKeyRelation( + from_table=from_table, # type: ignore[arg-type] + from_column=from_column, # type: ignore[arg-type] + to_table=to_table, # type: ignore[arg-type] + to_column=to_column, # type: ignore[arg-type] + ) + + self._validate_dataset_exists(relation.from_table) + self._validate_dataset_exists(relation.to_table) + + self._foreign_key_relations.append(relation) + return relation + + def build( + self, + dataset_names: Sequence[str] | None = None, + column_overrides: Mapping[str, Sequence[ColumnLike]] | None = None, + ) -> dict[str, DataFrame]: + """ + Materialize one or more datasets managed by the builder. + + :param dataset_names: Optional list of dataset names to build (defaults to all datasets) + :param column_overrides: Optional mapping of dataset name to column overrides + :return: Dictionary keyed by dataset name containing Spark ``DataFrame`` objects + """ + targets = dataset_names or list(self._datasets) + results: dict[str, DataFrame] = {} + + for name in targets: + overrides = column_overrides[name] if column_overrides and name in column_overrides else None + results[name] = self.get_dataset(name, columns=overrides) + + return results + + def get_dataset(self, name: str, columns: Sequence[ColumnLike] | None = None) -> DataFrame: + """ + Retrieve a single dataset as a ``DataFrame`` applying optional column overrides. + + :param name: Dataset name + :param columns: Optional select expressions to override defaults + :return: Spark ``DataFrame`` with projected columns + """ + dataset = self._datasets.get(name) + if dataset is None: + raise DataGenError(f"Dataset '{name}' is not defined.") + + if dataset.dataframe is not None: + return dataset.select_columns(dataset.dataframe, columns) + + assert dataset.generator is not None + self._ensure_shared_generator(name) + + base_df = self._get_or_build_generator_output(dataset.generator) + return dataset.select_columns(base_df, columns) + + def clear_cache(self) -> None: + """ + Clear cached ``DataFrame`` results for generator-backed datasets. 
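A minimal usage sketch of the MultiTableBuilder API above; the table and column names are illustrative only, and note that both generator-backed datasets deliberately share one `DataGenerator` instance as the relation check requires:

    from dbldatagen import DataGenerator, MultiTableBuilder
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    shared_gen = (DataGenerator(spark, rows=1000)
                  .withColumn("customer_id", "long", minValue=1, maxValue=100, random=True)
                  .withColumn("order_id", "long", minValue=1, maxValue=1000, random=True))

    mtb = MultiTableBuilder()
    mtb.add_data_generator("customers", shared_gen, columns=["customer_id"])
    mtb.add_data_generator("orders", shared_gen, columns=["order_id", "customer_id"])
    mtb.add_foreign_key_relation(from_table="orders", from_column="customer_id",
                                 to_table="customers", to_column="customer_id")

    tables = mtb.build()          # dict of dataset name -> DataFrame
    orders_df = tables["orders"]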
+ """ + self._generator_cache.clear() + + def _validate_dataset_exists(self, name: str) -> None: + if name not in self._datasets: + raise DataGenError(f"Dataset '{name}' is not registered with the builder.") + + def _get_or_build_generator_output(self, generator: DataGenerator) -> DataFrame: + generator_id = id(generator) + if generator_id not in self._generator_cache: + self._generator_cache[generator_id] = generator.build() + return self._generator_cache[generator_id] + + def _ensure_shared_generator(self, name: str) -> None: + """ + Validate that all generator-backed tables within the relation group share the same generator instance. + """ + dataset = self._datasets[name] + generator = dataset.generator + if generator is None: + return + + for related_name in self._collect_related_tables(name): + related_dataset = self._datasets[related_name] + if related_dataset.generator is None: + continue + if related_dataset.generator is not generator: + msg = ( + f"Datasets '{name}' and '{related_name}' participate in a foreign key relation " + "and must share the same DataGenerator instance." + ) + raise DataGenError(msg) + + def _collect_related_tables(self, name: str) -> set[str]: + """ + Collect all tables connected to the supplied table via foreign key relations. + """ + related: set[str] = set() + to_visit = [name] + + while to_visit: + current = to_visit.pop() + for relation in self._foreign_key_relations: + neighbor = self._neighbor_for_relation(current, relation) + if neighbor and neighbor not in related: + related.add(neighbor) + to_visit.append(neighbor) + + return related + + @staticmethod + def _neighbor_for_relation(table: str, relation: ForeignKeyRelation) -> str | None: + if relation.from_table == table: + return relation.to_table + if relation.to_table == table: + return relation.from_table + return None + + +@dataclass +class _DatasetDefinition: + """ + Internal representation of a dataset tracked by a ``MultiTableBuilder``. + """ + + name: str + generator: DataGenerator | None = None + dataframe: DataFrame | None = None + columns: tuple[ColumnLike, ...] | None = None + + def __post_init__(self) -> None: + has_generator = self.generator is not None + has_dataframe = self.dataframe is not None + + if has_generator == has_dataframe: + raise DataGenError(f"Dataset '{self.name}' must specify exactly one of DataGenerator or DataFrame.") + + def select_columns(self, df: DataFrame, overrides: Sequence[ColumnLike] | None = None) -> DataFrame: + """ + Apply column selection for the dataset using overrides when supplied. + + :param df: Source ``DataFrame`` to project + :param overrides: Optional column expressions to use instead of defaults + :return: Projected ``DataFrame`` + """ + select_exprs = overrides if overrides is not None else self.columns + if not select_exprs: + return df + + normalized_columns = [ensure_column(expr) for expr in select_exprs] + return df.select(*normalized_columns) diff --git a/dbldatagen/nrange.py b/dbldatagen/nrange.py index 283b83da..13734630 100644 --- a/dbldatagen/nrange.py +++ b/dbldatagen/nrange.py @@ -3,154 +3,238 @@ # """ -This module defines the `NRange` class used to specify data ranges +This module defines the `NRange` class used to specify numeric data ranges. 
""" import math -from pyspark.sql.types import LongType, FloatType, IntegerType, DoubleType, ShortType, \ - ByteType, DecimalType +from pyspark.sql.types import ( + ByteType, + DataType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, +) from .datarange import DataRange from .serialization import SerializableToDict -_OLD_MIN_OPTION = 'min' -_OLD_MAX_OPTION = 'max' + +_OLD_MIN_OPTION = "min" +_OLD_MAX_OPTION = "max" class NRange(DataRange): - """ Ranged numeric interval representing the interval minValue .. maxValue inclusive + """Represents a numeric interval for data generation. + + The numeric range represents the interval `minValue .. maxValue` inclusive and can be + used as an alternative to the `minValue`, `maxValue`, and `step` parameters passed to + the `DataGenerator.withColumn` method. Specify by passing an instance of `NRange` + to the `dataRange` parameter. + + For a decreasing sequence, use a negative `step` value. + + :param minValue: Minimum value of range (integer / long / float). + :param maxValue: Maximum value of range (integer / long / float). + :param step: Step value for range (integer / long / float). + :param until: Upper bound for the range (i.e. `maxValue + 1`). - A ranged object can be uses as an alternative to the `minValue`, `maxValue`, `step` parameters - to the DataGenerator `withColumn` and `withColumn` objects. - Specify by passing an instance of `NRange` to the `dataRange` parameter. + You may only specify either `maxValue` or `until`, but not both. For backwards + compatibility, the legacy `min` and `max` keyword arguments are still supported + but `minValue` and `maxValue` are preferred. - :param minValue: Minimum value of range. May be integer / long / float - :param maxValue: Maximum value of range. May be integer / long / float - :param step: Step value for range. May be integer / long / float - :param until: Upper bound for range ( i.e maxValue+1) + .. note:: + The `until` parameter is used to specify the upper bound for the range. It is + equivalent to `maxValue + 1`. You may specify either `maxValue` or `until`, but not both. - You may only specify a `maxValue` or `until` value not both. + .. note:: + The `step` parameter is used to specify the step value for the range. It is + used to generate the range of values. - For a decreasing sequence, use a negative step value. + .. note:: + For backwards compatibility, the legacy `min` and `max` keyword arguments are still + supported. Using `minValue` and `maxValue` is strongly preferred. """ - def __init__(self, minValue=None, maxValue=None, step=None, until=None, **kwArgs): - # check if older form of `minValue` and `maxValue` are used, and if so - if _OLD_MIN_OPTION in kwArgs: - assert minValue is None, \ - "Only one of `minValue` and `minValue` can be specified. Use of `minValue` is preferred" - self.minValue = kwArgs[_OLD_MIN_OPTION] - kwArgs.pop(_OLD_MIN_OPTION, None) + minValue: float | int | None + maxValue: float | int | None + step: float | int | None + + def __init__( + self, + minValue: int | float | None = None, + maxValue: int | float | None = None, + step: int | float | None = None, + until: int | float | None = None, + **kwargs: object, + ) -> None: + """Initializes a numeric range. + + :param minValue: Minimum value of range (integer / long / float). + :param maxValue: Maximum value of range (integer / long / float). + :param step: Step value for range (integer / long / float). + :param until: Upper bound for the range (i.e. `maxValue + 1`). 
+ + You may only specify either `maxValue` or `until`, but not both. For backwards + compatibility, the legacy `min` and `max` keyword arguments are still supported + but `minValue` and `maxValue` are preferred. + """ + # Handle older form of `minValue` and `maxValue` arguments (`min`/`max`) if used. + if _OLD_MIN_OPTION in kwargs: + if minValue is not None: + raise ValueError("Only one of 'minValue' and legacy 'min' may be specified") + if not isinstance(kwargs[_OLD_MIN_OPTION], int | float): + raise ValueError("Legacy 'min' argument must be an integer or float.") + self.minValue = kwargs[_OLD_MIN_OPTION] # type: ignore else: self.minValue = minValue - if _OLD_MAX_OPTION in kwArgs: - assert maxValue is None, \ - "Only one of `maxValue` and `maxValue` can be specified. Use of `maxValue` is preferred" - self.maxValue = kwArgs[_OLD_MAX_OPTION] - kwArgs.pop(_OLD_MAX_OPTION, None) + if _OLD_MAX_OPTION in kwargs: + if maxValue is not None: + raise ValueError("Only one of 'maxValue' and legacy 'max' may be specified") + if not isinstance(kwargs[_OLD_MAX_OPTION], int | float): + raise ValueError("Legacy 'max' argument must be an integer or float.") + self.maxValue = kwargs[_OLD_MAX_OPTION] # type: ignore else: self.maxValue = maxValue - assert len(kwArgs.keys()) == 0, "no keyword options other than `min` and `max` allowed" - assert until is None if self.maxValue is not None else True, "Only one of maxValue or until can be specified" - assert self.maxValue is None if until is not None else True, "Only one of maxValue or until can be specified" + unsupported_kwargs = kwargs.keys() - {_OLD_MIN_OPTION, _OLD_MAX_OPTION} + if len(unsupported_kwargs) > 0: + unexpected = ", ".join(sorted(unsupported_kwargs)) + raise ValueError(f"Unexpected keyword arguments for NRange: {unexpected}") + + if self.maxValue is not None and until is not None: + raise ValueError("Only one of 'maxValue' or 'until' may be specified.") if until is not None: self.maxValue = until + 1 + self.step = step - def _toInitializationDict(self): - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + def _toInitializationDict(self) -> dict[str, object]: + """Convert this `NRange` instance to a dictionary of constructor arguments. + + :return: Dictionary representation of the object. """ - _options = { + _options: dict[str, object] = { "kind": self.__class__.__name__, "minValue": self.minValue, "maxValue": self.maxValue, - "step": self.step + "step": self.step, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } - def __str__(self): + def __str__(self) -> str: + """Return a string representation of the numeric range.""" return f"NRange({self.minValue}, {self.maxValue}, {self.step})" - def isEmpty(self): - """Check if object is empty (i.e all instance vars of note are `None` + def isEmpty(self) -> bool: + """Check if the range is empty. + + An `NRange` is considered empty if all of `minValue`, `maxValue`, and `step` are `None`. - :returns: `True` if empty, `False` otherwise + :return: `True` if empty, `False` otherwise. 
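A small sketch of the range semantics described above (values follow directly from the methods defined later in this module):

    from dbldatagen import NRange

    r = NRange(minValue=0, maxValue=4, step=0.5)
    r.getDiscreteRange()    # 8.0 -- maxValue itself is not counted
    r.getContinuousRange()  # 4.0
    r.getScale()            # 1, one decimal place coming from the step

    # Passed to a column spec in place of minValue/maxValue/step, e.g.:
    # .withColumn("reading", "float", dataRange=NRange(1.0, 9.0, 0.5), random=True)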
""" return self.minValue is None and self.maxValue is None and self.step is None - def isFullyPopulated(self): - """Check is all instance vars are populated + def isFullyPopulated(self) -> bool: + """Check if all range attributes are populated. - :returns: `True` if fully populated, `False` otherwise + :return: `True` if `minValue`, `maxValue`, and `step` are all not `None`, + `False` otherwise. """ return self.minValue is not None and self.maxValue is not None and self.step is not None - def adjustForColumnDatatype(self, ctype): - """ Adjust default values for column output type + def adjustForColumnDatatype(self, ctype: DataType) -> None: + """Adjust default values for the specified Spark SQL column data type. + + This will: - :param ctype: Spark SQL type instance to adjust range for - :returns: No return value - executes for effect only + - Populate `minValue` and `maxValue` to the default range for the data type + if they are not already set. + - Validate that `maxValue` is within the allowed range for `ByteType` and + `ShortType`. + - Set a default `step` of 1.0 for floating point types and 1 for integral types + if `step` is not already set. + + :param ctype: Spark SQL data type for the column. """ - numeric_types = [DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType] - if type(ctype) in numeric_types: - if self.minValue is None: - self.minValue = NRange._getNumericDataTypeRange(ctype)[0] - if self.maxValue is None: - self.maxValue = NRange._getNumericDataTypeRange(ctype)[1] + numeric_types = (DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType) + + if isinstance(ctype, numeric_types): + numeric_range = NRange._getNumericDataTypeRange(ctype) + if numeric_range is not None: + if self.minValue is None: + self.minValue = numeric_range[0] + if self.maxValue is None: + self.maxValue = numeric_range[1] - if type(ctype) is ShortType and self.maxValue is not None: - assert self.maxValue <= 65536, "`maxValue` must be in range of short" + if isinstance(ctype, ShortType) and self.maxValue is not None: + if self.maxValue > 65536: + raise ValueError("`maxValue` must be within the valid range for ShortType.") - if type(ctype) is ByteType and self.maxValue is not None: - assert self.maxValue <= 256, "`maxValue` must be in range of byte (0 - 256)" + if isinstance(ctype, ByteType) and self.maxValue is not None: + if self.maxValue > 256: + raise ValueError("`maxValue` must be within the valid range (0 - 256) for ByteType.") - if (type(ctype) is DoubleType or type(ctype) is FloatType) and self.step is None: + if isinstance(ctype, (DoubleType, FloatType)) and self.step is None: self.step = 1.0 - if (type(ctype) is ByteType - or type(ctype) is ShortType - or type(ctype) is IntegerType - or type(ctype) is LongType) and self.step is None: + if isinstance(ctype, (ByteType, ShortType, IntegerType, LongType)) and self.step is None: self.step = 1 - def getDiscreteRange(self): - """Convert range to discrete range + def getDiscreteRange(self) -> float: + """Convert range to a discrete range size. + + This is the number of discrete values in the range. For example, `NRange(1, 5, 0.5)` + has 8 discrete values. - :returns: number of discrete values in range. For example `NRange(1, 5, 0.5)` has 8 discrete values + :return: Number of discrete values in the range. + :raises ValueError: If the range is not fully specified or `step` is zero. .. 
note:: - A range of 0,4, 0.5 has 8 discrete values not 9 as the `maxValue` value is not part of the range + A range of `NRange(0, 4, 0.5)` has 8 discrete values, not 9, as the `maxValue` + value itself is not part of the range. + """ + if self.minValue is None or self.maxValue is None or self.step is None: + raise ValueError("Range must have 'minValue', 'maxValue', and 'step' defined.") - TODO: check range of values + if self.step == 0: + raise ValueError("Parameter 'step' must be non-zero when computing discrete range.") - """ - if type(self.minValue) is int and type(self.maxValue) is int and self.step == 1: - return self.maxValue - self.minValue - else: - # when any component is a float, we will return a float for the discrete range - # to simplify computations - return float(math.floor((self.maxValue - self.minValue) * float(1.0 / self.step))) + if isinstance(self.minValue, int) and isinstance(self.maxValue, int) and self.step == 1: + return float(self.maxValue - self.minValue) - def getContinuousRange(self): - """Convert range to continuous range + # when any component is a float, we will return a float for the discrete range + # to simplify computations + return float(math.floor((self.maxValue - self.minValue) * float(1.0 / self.step))) - :returns: float value for size of interval from `minValue` to `maxValue` + def getContinuousRange(self) -> float: + """Convert range to continuous range. + + :return: Float value for the size of the interval from `minValue` to `maxValue`. + :raises ValueError: If `minValue` or `maxValue` is not defined. """ - return (self.maxValue - self.minValue) * float(1.0) + if self.minValue is None or self.maxValue is None: + raise ValueError("Range must have 'minValue' and 'maxValue' defined.") + + return (self.maxValue - self.minValue) * 1.0 + + def getScale(self) -> int: + """Get the maximum scale (number of decimal places) of the range components. - def getScale(self): - """Get scale of range""" - smin, smax, sstep = 0, 0, 0 + :return: Maximum scale across `minValue`, `maxValue`, and `step`. + """ + smin = 0 + smax = 0 + sstep = 0 if self.minValue is not None: smin = self._precision_and_scale(self.minValue)[1] @@ -163,30 +247,41 @@ def getScale(self): return max(smin, smax, sstep) @staticmethod - def _precision_and_scale(x): + def _precision_and_scale(x: float | int) -> tuple[int, int]: + """Compute precision and scale for a numeric value. + + :param x: Numeric value for which to compute precision and scale. + :return: Tuple of `(precision, scale)`. + """ max_digits = 14 int_part = int(abs(x)) magnitude = 1 if int_part == 0 else int(math.log10(int_part)) + 1 if magnitude >= max_digits: - return (magnitude, 0) + return magnitude, 0 + frac_part = abs(x) - int_part multiplier = 10 ** (max_digits - magnitude) frac_digits = multiplier + int(multiplier * frac_part + 0.5) while frac_digits % 10 == 0: - frac_digits /= 10 + frac_digits //= 10 scale = int(math.log10(frac_digits)) - return (magnitude + scale, scale) + return magnitude + scale, scale @staticmethod - def _getNumericDataTypeRange(ctype): - value_ranges = { - ByteType: (0, (2 ** 4 - 1)), - ShortType: (0, (2 ** 8 - 1)), - IntegerType: (0, (2 ** 16 - 1)), - LongType: (0, (2 ** 32 - 1)), + def _getNumericDataTypeRange(ctype: DataType) -> tuple[float | int, float | int] | None: + """Get the default numeric range for the specified Spark SQL data type. + + :param ctype: Spark SQL data type. + :return: Tuple of `(minValue, maxValue)` for the type, or `None` if not supported. 
+ """ + value_ranges: dict[type, tuple[float | int, float | int]] = { + ByteType: (0, (2**4 - 1)), + ShortType: (0, (2**8 - 1)), + IntegerType: (0, (2**16 - 1)), + LongType: (0, (2**32 - 1)), FloatType: (0.0, 3.402e38), - DoubleType: (0.0, 1.79769e308) + DoubleType: (0.0, 1.79769e308), } - if type(ctype) is DecimalType: + if isinstance(ctype, DecimalType): return 0.0, math.pow(10, ctype.precision - ctype.scale) - 1.0 return value_ranges.get(type(ctype), None) diff --git a/dbldatagen/relation.py b/dbldatagen/relation.py new file mode 100644 index 00000000..5864976e --- /dev/null +++ b/dbldatagen/relation.py @@ -0,0 +1,33 @@ +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module defines the ``ForeignKeyRelation`` class used for describing foreign key relations between datasets. +""" + +from dataclasses import dataclass + +from dbldatagen.datagen_types import ColumnLike +from dbldatagen.utils import ensure_column + + +@dataclass(frozen=True) +class ForeignKeyRelation: + """ + Dataclass describing a foreign key relation between two datasets managed by a ``MultiTableBuilder``. + + :param from_table: Name of the referencing table + :param from_column: Referencing column as a string or ``pyspark.sql.Column`` expression + :param to_table: Name of the referenced table + :param to_column: Referenced column as a string or ``pyspark.sql.Column`` expression + """ + + from_table: str + from_column: ColumnLike + to_table: str + to_column: ColumnLike + + def __post_init__(self) -> None: + object.__setattr__(self, "from_column", ensure_column(self.from_column)) + object.__setattr__(self, "to_column", ensure_column(self.to_column)) diff --git a/dbldatagen/schema_parser.py b/dbldatagen/schema_parser.py index e1b50e4b..955a7fa9 100644 --- a/dbldatagen/schema_parser.py +++ b/dbldatagen/schema_parser.py @@ -8,15 +8,32 @@ import re import pyparsing as pp -from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \ - TimestampType, DateType, DecimalType, ByteType, BinaryType, StructField, StructType, MapType, ArrayType +from pyspark.sql.types import ( + LongType, + FloatType, + IntegerType, + StringType, + DoubleType, + BooleanType, + ShortType, + TimestampType, + DateType, + DecimalType, + ByteType, + BinaryType, + StructField, + StructType, + MapType, + ArrayType, +) class SchemaParser(object): - """ SchemaParser class + """SchemaParser class - Creates pyspark SQL datatype from string + Creates pyspark SQL datatype from string """ + _type_parser = None @classmethod @@ -91,17 +108,33 @@ def getTypeDefinitionParser(cls): # handle decimal types of the form "decimal(10)", "real", "numeric(10,3)" decimal_keyword = pp.MatchFirst( - [pp.CaselessKeyword("decimal"), pp.CaselessKeyword("dec"), pp.CaselessKeyword("number"), - pp.CaselessKeyword("numeric")]) + [ + pp.CaselessKeyword("decimal"), + pp.CaselessKeyword("dec"), + pp.CaselessKeyword("number"), + pp.CaselessKeyword("numeric"), + ] + ) decimal_keyword.setParseAction(lambda s: "decimal") # first number is precision , default 10; second number is scale, default 0 decimal_type_expr = decimal_keyword + pp.Optional( - lbracket + number + pp.Optional(comma + number, "0") + rbracket) - - primitive_type_keyword = (int_keyword ^ bigint_keyword ^ binary_keyword ^ boolean_keyword - ^ date_keyword ^ float_keyword ^ smallint_keyword ^ timestamp_keyword - ^ tinyint_keyword ^ string_type_expr ^ decimal_type_expr ^ double_keyword - 
).setName("primitive_type_defn") + lbracket + number + pp.Optional(comma + number, "0") + rbracket + ) + + primitive_type_keyword = ( + int_keyword + ^ bigint_keyword + ^ binary_keyword + ^ boolean_keyword + ^ date_keyword + ^ float_keyword + ^ smallint_keyword + ^ timestamp_keyword + ^ tinyint_keyword + ^ string_type_expr + ^ decimal_type_expr + ^ double_keyword + ).setName("primitive_type_defn") # handle more complex type definitions such as struct, map and array @@ -128,8 +161,12 @@ def getTypeDefinitionParser(cls): # handle structs struct_keyword = pp.CaselessKeyword("struct") - struct_expr = struct_keyword + l_angle + pp.Group( - pp.delimitedList(pp.Group(ident + pp.Optional(colon) + pp.Group(type_expr)))) + r_angle + struct_expr = ( + struct_keyword + + l_angle + + pp.Group(pp.delimitedList(pp.Group(ident + pp.Optional(colon) + pp.Group(type_expr)))) + + r_angle + ) # try to capture invalid type name for better error reporting invalid_type = pp.Word(pp.alphas, pp.alphanums + "_", asKeyword=True) @@ -220,7 +257,7 @@ def _parse_ast(cls, ast): @classmethod def columnTypeFromString(cls, type_string): - """ Generate a Spark SQL data type from a string + """Generate a Spark SQL data type from a string Allowable options for `type_string` parameter are: * `string`, `varchar`, `char`, `nvarchar`, @@ -262,7 +299,7 @@ def columnTypeFromString(cls, type_string): @classmethod def _cleanseSQL(cls, sql_string): - """ Cleanse sql string removing string literals so that they are not considered as part of potential column + """Cleanse sql string removing string literals so that they are not considered as part of potential column references :param sql_string: String representation of SQL expression :returns: cleansed string @@ -290,7 +327,7 @@ def _cleanseSQL(cls, sql_string): @classmethod def columnsReferencesFromSQLString(cls, sql_string, filterItems=None): - """ Generate a list of possible column references from a SQL string + """Generate a list of possible column references from a SQL string This method finds all condidate references to SQL columnn ids in the string @@ -325,20 +362,21 @@ def columnsReferencesFromSQLString(cls, sql_string, filterItems=None): @classmethod def parseCreateTable(cls, sparkSession, source_schema): - """ Parse a schema from a schema string + """Parse a schema from a schema string - :param sparkSession: spark session to use - :param source_schema: should be a table definition minus the create table statement - :returns: Spark SQL schema instance + :param sparkSession: spark session to use + :param source_schema: should be a table definition minus the create table statement + :returns: Spark SQL schema instance """ - assert (source_schema is not None), "`source_schema` must be specified" - assert (sparkSession is not None), "`sparkSession` must be specified" + assert source_schema is not None, "`source_schema` must be specified" + assert sparkSession is not None, "`sparkSession` must be specified" lines = [x.strip() for x in source_schema.split("\n") if x is not None] table_defn = " ".join(lines) # get table name from s - table_name_match = re.search(r"^\s*create\s*(temporary\s*)?table\s*([a-zA-Z0-9_]*)\s*(\(.*)$", table_defn, - flags=re.IGNORECASE) + table_name_match = re.search( + r"^\s*create\s*(temporary\s*)?table\s*([a-zA-Z0-9_]*)\s*(\(.*)$", table_defn, flags=re.IGNORECASE + ) if table_name_match: table_name = table_name_match.group(2) diff --git a/dbldatagen/serialization.py b/dbldatagen/serialization.py index 0bd82d75..fb902855 100644 --- 
a/dbldatagen/serialization.py +++ b/dbldatagen/serialization.py @@ -1,21 +1,23 @@ -""" Defines interface contracts.""" +"""Defines interface contracts.""" + import sys from typing import TypeVar T = TypeVar("T", bound="SerializableToDict") + class SerializableToDict: - """ Serializable objects must implement a `_getConstructorOptions` method - which converts the object properties to a Python dictionary whose keys - are the named arguments to the class constructor. + """Serializable objects must implement a `_getConstructorOptions` method + which converts the object properties to a Python dictionary whose keys + are the named arguments to the class constructor. """ @classmethod def _fromInitializationDict(cls: type[T], options: dict) -> T: - """ Converts a Python dictionary to an object using the object's constructor. - :param options: Python dictionary with class constructor options - :return: An instance of the class + """Converts a Python dictionary to an object using the object's constructor. + :param options: Python dictionary with class constructor options + :return: An instance of the class """ _options: dict = options.copy() _options.pop("kind") @@ -33,9 +35,10 @@ def _fromInitializationDict(cls: type[T], options: dict) -> T: return cls(**_ir) def _toInitializationDict(self) -> dict: - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + :return: Python dictionary representation of the object """ raise NotImplementedError( - f"Object is not serializable. {self.__class__.__name__} does not implement '_toInitializationDict'") + f"Object is not serializable. {self.__class__.__name__} does not implement '_toInitializationDict'" + ) diff --git a/dbldatagen/spark_singleton.py b/dbldatagen/spark_singleton.py index 9680449c..c3d8c0fd 100644 --- a/dbldatagen/spark_singleton.py +++ b/dbldatagen/spark_singleton.py @@ -28,7 +28,9 @@ def getInstance(cls: type["SparkSingleton"]) -> SparkSession: return SparkSession.builder.getOrCreate() @classmethod - def getLocalInstance(cls: type["SparkSingleton"], appName: str = "new Spark session", useAllCores: bool = True) -> SparkSession: + def getLocalInstance( + cls: type["SparkSingleton"], appName: str = "new Spark session", useAllCores: bool = True + ) -> SparkSession: """Creates a machine local `SparkSession` instance for Datalib. By default, it uses `n-1` cores of the available cores for the spark session, where `n` is total cores available. 
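A minimal sketch of the two SparkSingleton entry points shown above (the application name is arbitrary):

    from dbldatagen.spark_singleton import SparkSingleton

    spark = SparkSingleton.getInstance()                  # reuse or create the active session
    local = SparkSingleton.getLocalInstance("datagen tests")   # machine-local session for testing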
@@ -49,10 +51,11 @@ def getLocalInstance(cls: type["SparkSingleton"], appName: str = "new Spark sess logger = logging.getLogger(__name__) logger.info("Spark core count: %d", spark_core_count) - sparkSession = SparkSession.builder \ - .master(f"local[{spark_core_count}]") \ - .appName(appName) \ - .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") \ + sparkSession = ( + SparkSession.builder.master(f"local[{spark_core_count}]") + .appName(appName) + .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") .getOrCreate() + ) return sparkSession diff --git a/dbldatagen/text_generator_plugins.py b/dbldatagen/text_generator_plugins.py index 11cf8c61..ba4759df 100644 --- a/dbldatagen/text_generator_plugins.py +++ b/dbldatagen/text_generator_plugins.py @@ -27,6 +27,7 @@ class _FnCallContext: :param txtGen: - reference to outer PyfnText object """ + textGenerator: "TextGenerator" def __init__(self, txtGen: "TextGenerator") -> None: @@ -84,6 +85,7 @@ class PyfuncText(TextGenerator): # lgtm [py/missing-equals] The code does not guarantee thread or cross process safety. If a new instance of the random number generator is needed, you may call the base class method with the argument `forceNewInstance` set to True. """ + _name: str _initPerBatch: bool _rootProperty: object @@ -98,7 +100,7 @@ def __init__( init: Callable | None = None, initPerBatch: bool = False, name: str | None = None, - rootProperty: object = None + rootProperty: object = None, ) -> None: super().__init__() if not callable(fn): @@ -216,6 +218,7 @@ def initFaker(ctx): rootProperty="faker", name="FakerText")) """ + _name: str _initPerBatch: bool _initFn: Callable | None @@ -268,13 +271,7 @@ def withRootProperty(self, prop: object) -> "PyfuncTextFactory": self._rootProperty = prop return self - def __call__( - self, - evalFn: str | Callable, - *args, - isProperty: bool = False, - **kwargs - ) -> PyfuncText: + def __call__(self, evalFn: str | Callable, *args, isProperty: bool = False, **kwargs) -> PyfuncText: """ Internal function calling mechanism that implements the syntax expansion. 
@@ -339,7 +336,7 @@ def __init__( providers: list | None = None, name: str = "FakerText", lib: str | None = None, - rootClass: str | None = None + rootClass: str | None = None, ) -> None: super().__init__(name) diff --git a/dbldatagen/text_generators.py b/dbldatagen/text_generators.py index ad7c6e10..68875610 100644 --- a/dbldatagen/text_generators.py +++ b/dbldatagen/text_generators.py @@ -31,12 +31,64 @@ _DIGITS_ZERO = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] #: list of uppercase letters for template generation -_LETTERS_UPPER = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", - "Q", "R", "T", "S", "U", "V", "W", "X", "Y", "Z"] +_LETTERS_UPPER = [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "T", + "S", + "U", + "V", + "W", + "X", + "Y", + "Z", +] #: list of lowercase letters for template generation -_LETTERS_LOWER = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", - "r", "s", "t", "u", "v", "w", "x", "y", "z"] +_LETTERS_LOWER = [ + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", +] #: list of all letters uppercase and lowercase _LETTERS_ALL = _LETTERS_LOWER + _LETTERS_UPPER @@ -48,28 +100,156 @@ _ALNUM_UPPER = _LETTERS_UPPER + _DIGITS_ZERO """ words for ipsum lorem based text generation""" -_WORDS_LOWER = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit", "sed", "do", - "eiusmod", "tempor", "incididunt", "ut", "labore", "et", "dolore", "magna", "aliqua", "ut", - "enim", "ad", "minim", "veniam", "quis", "nostrud", "exercitation", "ullamco", "laboris", - "nisi", "ut", "aliquip", "ex", "ea", "commodo", "consequat", "duis", "aute", "irure", "dolor", - "in", "reprehenderit", "in", "voluptate", "velit", "esse", "cillum", "dolore", "eu", "fugiat", - "nulla", "pariatur", "excepteur", "sint", "occaecat", "cupidatat", "non", "proident", "sunt", - "in", "culpa", "qui", "officia", "deserunt", "mollit", "anim", "id", "est", "laborum"] - -_WORDS_UPPER = ["LOREM", "IPSUM", "DOLOR", "SIT", "AMET", "CONSECTETUR", "ADIPISCING", "ELIT", "SED", "DO", - "EIUSMOD", "TEMPOR", "INCIDIDUNT", "UT", "LABORE", "ET", "DOLORE", "MAGNA", "ALIQUA", "UT", - "ENIM", "AD", "MINIM", "VENIAM", "QUIS", "NOSTRUD", "EXERCITATION", "ULLAMCO", "LABORIS", - "NISI", "UT", "ALIQUIP", "EX", "EA", "COMMODO", "CONSEQUAT", "DUIS", "AUTE", "IRURE", - "DOLOR", "IN", "REPREHENDERIT", "IN", "VOLUPTATE", "VELIT", "ESSE", "CILLUM", "DOLORE", - "EU", "FUGIAT", "NULLA", "PARIATUR", "EXCEPTEUR", "SINT", "OCCAECAT", "CUPIDATAT", "NON", - "PROIDENT", "SUNT", "IN", "CULPA", "QUI", "OFFICIA", "DESERUNT", "MOLLIT", "ANIM", "ID", "EST", - "LABORUM"] +_WORDS_LOWER = [ + "lorem", + "ipsum", + "dolor", + "sit", + "amet", + "consectetur", + "adipiscing", + "elit", + "sed", + "do", + "eiusmod", + "tempor", + "incididunt", + "ut", + "labore", + "et", + "dolore", + "magna", + "aliqua", + "ut", + "enim", + "ad", + "minim", + "veniam", + "quis", + "nostrud", + "exercitation", + "ullamco", + "laboris", + "nisi", + "ut", + "aliquip", + "ex", + "ea", + "commodo", + "consequat", + "duis", + "aute", + "irure", + "dolor", + "in", + "reprehenderit", + "in", + "voluptate", + "velit", + "esse", + "cillum", + "dolore", + "eu", + "fugiat", + "nulla", + "pariatur", + "excepteur", + "sint", + "occaecat", + "cupidatat", + "non", + "proident", + "sunt", + 
"in", + "culpa", + "qui", + "officia", + "deserunt", + "mollit", + "anim", + "id", + "est", + "laborum", +] + +_WORDS_UPPER = [ + "LOREM", + "IPSUM", + "DOLOR", + "SIT", + "AMET", + "CONSECTETUR", + "ADIPISCING", + "ELIT", + "SED", + "DO", + "EIUSMOD", + "TEMPOR", + "INCIDIDUNT", + "UT", + "LABORE", + "ET", + "DOLORE", + "MAGNA", + "ALIQUA", + "UT", + "ENIM", + "AD", + "MINIM", + "VENIAM", + "QUIS", + "NOSTRUD", + "EXERCITATION", + "ULLAMCO", + "LABORIS", + "NISI", + "UT", + "ALIQUIP", + "EX", + "EA", + "COMMODO", + "CONSEQUAT", + "DUIS", + "AUTE", + "IRURE", + "DOLOR", + "IN", + "REPREHENDERIT", + "IN", + "VOLUPTATE", + "VELIT", + "ESSE", + "CILLUM", + "DOLORE", + "EU", + "FUGIAT", + "NULLA", + "PARIATUR", + "EXCEPTEUR", + "SINT", + "OCCAECAT", + "CUPIDATAT", + "NON", + "PROIDENT", + "SUNT", + "IN", + "CULPA", + "QUI", + "OFFICIA", + "DESERUNT", + "MOLLIT", + "ANIM", + "ID", + "EST", + "LABORUM", +] class TextGenerator(ABC): """ Base class for all text generation classes. """ + _randomSeed: int _rngInstance: numpy.random.Generator | None @@ -117,8 +297,11 @@ def getNPRandomGenerator(self, forceNewInstance: bool = False) -> numpy.random.G :return: Random number generator initialized from previously supplied random seed. """ - assert self._randomSeed is None or type(self._randomSeed) in [int, np.int32, np.int64], \ - f"`random_seed` must be int or int-like not {type(self._randomSeed)}" + assert self._randomSeed is None or type(self._randomSeed) in [ + int, + np.int32, + np.int64, + ], f"`random_seed` must be int or int-like not {type(self._randomSeed)}" if self._rngInstance is not None and not forceNewInstance: return self._rngInstance @@ -157,24 +340,23 @@ def compactNumpyTypeForValues(listValues: list | numpy.ndarray) -> np.dtype: @staticmethod def getAsTupleOrElse( - v: int | tuple[int, int] | None, - defaultValue: tuple[int, int], - valueName: str = "value" + v: int | tuple[int, int] | None, defaultValue: tuple[int, int], valueName: str = "value" ) -> tuple[int, int]: - """ get value v as tuple or return default value + """get value v as tuple or return default value - :param v: value to test - :param defaultValue: value to use as a default if value of `v` is None. Must be a tuple. - :param valueName: name of value for debugging and logging purposes - :returns: return `v` as tuple if not `None` or value of `default_v` if `v` is `None`. If `v` is a single - value, returns the tuple (`v`, `v`)""" + :param v: value to test + :param defaultValue: value to use as a default if value of `v` is None. Must be a tuple. + :param valueName: name of value for debugging and logging purposes + :returns: return `v` as tuple if not `None` or value of `default_v` if `v` is `None`. 
If `v` is a single + value, returns the tuple (`v`, `v`)""" assert not v or isinstance(v, int | tuple), f"param {valueName} must be an int, a tuple or None" assert isinstance(defaultValue, tuple) and len(defaultValue) == 2, "default value must be tuple" if not v: assert len(defaultValue) == 2, "must have list or iterable with lenght 2" - assert isinstance(defaultValue[0], int) and isinstance(defaultValue[1], int), \ - "all elements must be integers" + assert isinstance(defaultValue[0], int) and isinstance( + defaultValue[1], int + ), "all elements must be integers" return defaultValue if isinstance(v, tuple): @@ -278,10 +460,7 @@ class TemplateGenerator(TextGenerator, SerializableToDict): # lgtm [py/missing- _templateEscapedMappings: dict[str, tuple[int, np.ndarray | None]] def __init__( - self, - template: str, - escapeSpecialChars: bool = False, - extendedWordList: list[str] | None = None + self, template: str, escapeSpecialChars: bool = False, extendedWordList: list[str] | None = None ) -> None: super().__init__() @@ -291,8 +470,9 @@ def __init__( self._escapeSpecialMeaning = bool(escapeSpecialChars) self._templates = self._splitTemplates(self._template) self._wordList = np.array(extendedWordList if extendedWordList is not None else _WORDS_LOWER) - self._upperWordList = np.array([x.upper() for x in extendedWordList] - if extendedWordList is not None else _WORDS_UPPER) + self._upperWordList = np.array( + [x.upper() for x in extendedWordList] if extendedWordList is not None else _WORDS_UPPER + ) self._np_digits_zero = np.array(_DIGITS_ZERO) self._np_digits_non_zero = np.array(_DIGITS_NON_ZERO) @@ -314,7 +494,7 @@ def __init__( "d": (10, self._np_digits_zero), "D": (9, self._np_digits_non_zero), "k": (36, self._np_alnum_lower), - "K": (36, self._np_alnum_upper) + "K": (36, self._np_alnum_upper), } # ensure that each mapping is mapping from string to list or numpy array @@ -323,15 +503,14 @@ def __init__( assert v and isinstance(v, tuple) and len(v) == 2, "value must be tuple of length 2" mapping_length, mappings = v assert isinstance(mapping_length, int), "mapping length must be of type int" - assert isinstance(mappings, list | np.ndarray), \ - "mappings are lists or numpy arrays" + assert isinstance(mappings, list | np.ndarray), "mappings are lists or numpy arrays" assert mapping_length == 0 or len(mappings) == mapping_length, "mappings must match mapping_length" self._templateEscapedMappings = { "n": (256, None), "N": (65536, None), "w": (self._lenWords, self._wordList), - "W": (self._lenWords, self._upperWordList) + "W": (self._lenWords, self._upperWordList), } # ensure that each escaped mapping is mapping from string to None, list or numpy array @@ -340,8 +519,7 @@ def __init__( assert v and isinstance(v, tuple) and len(v) == 2, "value must be tuple of length 2" mapping_length, mappings = v assert isinstance(mapping_length, int), "mapping length must be of type int" - assert mappings is None or isinstance(mappings, list | np.ndarray), \ - "mappings are lists or numpy arrays" + assert mappings is None or isinstance(mappings, list | np.ndarray), "mappings are lists or numpy arrays" # for escaped mappings, the mapping can be None in which case the mapping is to the number itself # i.e mapping[4] = 4 @@ -350,8 +528,10 @@ def __init__( # get the template metadata - this will be list of metadata entries for each template # for each template, metadata will be tuple of number of placeholders followed by list of random bounds # to be computed when replacing non static placeholder - 
template_info = [self._prepareTemplateStrings(template, escapeSpecialMeaning=escapeSpecialChars) - for template in self._templates] + template_info = [ + self._prepareTemplateStrings(template, escapeSpecialMeaning=escapeSpecialChars) + for template in self._templates + ] self._max_placeholders = max(x[0] for x in template_info) self._max_rnds_needed = max(len(x[1]) for x in template_info) @@ -363,7 +543,7 @@ def __repr__(self) -> str: @property def templates(self) -> list[str]: - """ Get effective templates for text generator""" + """Get effective templates for text generator""" return self._templates def classicGenerateText(self, v: str) -> str: @@ -410,11 +590,7 @@ def pandasGenerateText(self, v: pd.Series) -> pd.Series: # expand values into placeholders without affect masked values self._applyTemplateStringsForTemplate( - v, - self._templates[x], - masked_placeholders, - masked_rnds, - escapeSpecialMeaning=self._escapeSpecialMeaning + v, self._templates[x], masked_placeholders, masked_rnds, escapeSpecialMeaning=self._escapeSpecialMeaning ) # soften and clear mask, allowing modifications @@ -429,20 +605,20 @@ def pandasGenerateText(self, v: pd.Series) -> pd.Series: return results def _toInitializationDict(self) -> dict[str, Any]: - """ Converts an object to a Python dictionary. Keys represent the object's - constructor arguments. - :return: Python dictionary representation of the object + """Converts an object to a Python dictionary. Keys represent the object's + constructor arguments. + :return: Python dictionary representation of the object """ _options = { "kind": self.__class__.__name__, "template": self._template, "escapeSpecialChars": self._escapeSpecialChars, - "extendedWordList": self._extendedWordList + "extendedWordList": self._extendedWordList, } return { - k: v._toInitializationDict() - if isinstance(v, SerializableToDict) else v - for k, v in _options.items() if v is not None + k: v._toInitializationDict() if isinstance(v, SerializableToDict) else v + for k, v in _options.items() + if v is not None } @staticmethod @@ -552,13 +728,13 @@ def _prepareTemplateStrings(self, genTemplate: str, escapeSpecialMeaning: bool = return num_placeholders, retval def _applyTemplateStringsForTemplate( - self, - baseValue: pd.Series | pd.DataFrame, - genTemplate: str, - placeholders: np.ndarray, - rnds: np.ndarray, - *, - escapeSpecialMeaning: bool = False + self, + baseValue: pd.Series | pd.DataFrame, + genTemplate: str, + placeholders: np.ndarray, + rnds: np.ndarray, + *, + escapeSpecialMeaning: bool = False, ) -> np.ndarray: """ Vectorized implementation of template driven text substitution. 
Applies substitutions to placeholders using @@ -677,7 +853,7 @@ def _get_values_subelement(elem: int) -> np.ndarray: placeholders[unmasked_rows, num_placeholders] = value_mappings[rnds[unmasked_rows, rnd_offset]] else: placeholders[:, num_placeholders] = value_mappings[rnds[:, rnd_offset]] - elif unmasked_rows is not None: # type: ignore[unreachable] + elif unmasked_rows is not None: # type: ignore[unreachable] placeholders[unmasked_rows, num_placeholders] = rnds[unmasked_rows, rnd_offset] else: placeholders[:, num_placeholders] = rnds[:, rnd_offset] @@ -769,14 +945,15 @@ class ILText(TextGenerator, SerializableToDict): # lgtm [py/missing-equals] """ def __init__( - self, - paragraphs: int | tuple[int, int] | None = None, - sentences: int | tuple[int, int] | None = None, - words: int | tuple[int, int] | None = None, - extendedWordList: list[str] | None = None + self, + paragraphs: int | tuple[int, int] | None = None, + sentences: int | tuple[int, int] | None = None, + words: int | tuple[int, int] | None = None, + extendedWordList: list[str] | None = None, ) -> None: - assert paragraphs is not None or sentences is not None or words is not None, \ - "At least one of the params `paragraphs`, `sentences` or `words` must be specified" + assert ( + paragraphs is not None or sentences is not None or words is not None + ), "At least one of the params `paragraphs`, `sentences` or `words` must be specified" super().__init__() @@ -860,9 +1037,9 @@ def generateText(self, baseValues: list | pd.Series | np.ndarray, rowCount: int masked_offsets: np.ma.MaskedArray = np.ma.MaskedArray(word_offsets, mask=final_mask) # note numpy random differs from standard random in that it never produces upper bound - masked_offsets[~masked_offsets.mask] = rng.integers(self._wordOffsetSize, - size=output_shape, - dtype=self._wordOffsetType)[~masked_offsets.mask] + masked_offsets[~masked_offsets.mask] = rng.integers( + self._wordOffsetSize, size=output_shape, dtype=self._wordOffsetType + )[~masked_offsets.mask] # hardening a mask prevents masked values from being changed np.ma.harden_mask(masked_offsets) @@ -956,7 +1133,7 @@ def _toInitializationDict(self) -> dict[str, Any]: "paragraphs": self._paragraphs, "sentences": self._sentences, "words": self._words, - "extendedWordList": self._extendedWordList + "extendedWordList": self._extendedWordList, } return _options diff --git a/dbldatagen/utils.py b/dbldatagen/utils.py index e7acaa2b..60e0913f 100644 --- a/dbldatagen/utils.py +++ b/dbldatagen/utils.py @@ -18,10 +18,12 @@ from typing import Any import jmespath -from pyspark.sql import DataFrame +from pyspark.sql import Column, DataFrame +from pyspark.sql.functions import col from pyspark.sql.streaming.query import StreamingQuery from dbldatagen.config import OutputDataset +from dbldatagen.datagen_types import ColumnLike def deprecated(message: str = "") -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -36,8 +38,11 @@ def deprecated(message: str = "") -> Callable[[Callable[..., Any]], Callable[... def deprecated_decorator(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def deprecated_func(*args: object, **kwargs: object) -> object: - warnings.warn(f"`{func.__name__}` is a deprecated function or method. \n{message}", - category=DeprecationWarning, stacklevel=1) + warnings.warn( + f"`{func.__name__}` is a deprecated function or method. 
\n{message}", + category=DeprecationWarning, + stacklevel=1, + ) warnings.simplefilter("default", DeprecationWarning) return func(*args, **kwargs) @@ -49,13 +54,12 @@ def deprecated_func(*args: object, **kwargs: object) -> object: class DataGenError(Exception): """Exception class to represent data generation errors - :param msg: message related to error - :param baseException: underlying exception, if any that caused the issue + :param msg: message related to error + :param baseException: underlying exception, if any that caused the issue """ def __init__(self, msg: str, baseException: object | None = None) -> None: - """ constructor - """ + """constructor""" super().__init__(msg) self._underlyingException: object | None = baseException self._msg: str = msg @@ -96,11 +100,11 @@ def strip_margin(text: str) -> str: def mkBoundsList(x: int | list[int] | None, default: int | list[int]) -> tuple[bool, list[int]]: - """ make a bounds list from supplied parameter - otherwise use default + """make a bounds list from supplied parameter - otherwise use default - :param x: integer or list of 2 values that define bounds list - :param default: default value if X is `None` - :returns: list of form [x,y] + :param x: integer or list of 2 values that define bounds list + :param default: default value if X is `None` + :returns: list of form [x,y] """ if x is None: retval = (True, [default, default]) if isinstance(default, int) else (True, list(default)) @@ -116,11 +120,9 @@ def mkBoundsList(x: int | list[int] | None, default: int | list[int]) -> tuple[b def topologicalSort( - sources: list[tuple[str, set[str]]], - initial_columns: list[str] | None = None, - flatten: bool = True + sources: list[tuple[str, set[str]]], initial_columns: list[str] | None = None, flatten: bool = True ) -> list[str] | list[list[str]]: - """ Perform a topological sort over sources + """Perform a topological sort over sources Used to compute the column test data generation order of the column generation dependencies. 
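As a hedged illustration of the topologicalSort helper whose signature is reformatted above: assuming each source entry is a (column_name, set_of_dependency_names) pair, a call might look like the sketch below, returning an order in which every column appears after the columns it depends on. The column names here are invented for the example.

from dbldatagen.utils import topologicalSort

# hypothetical column dependency graph: (column, columns it depends on)
sources = [
    ("total", {"price", "qty"}),
    ("price", {"id"}),
    ("qty", {"id"}),
    ("id", set()),
]

build_order = topologicalSort(sources, flatten=True)  # e.g. ["id", "price", "qty", "total"]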
@@ -242,7 +244,7 @@ def parse_time_interval(spec: str) -> timedelta: milliseconds=milliseconds, minutes=minutes, hours=hours, - weeks=weeks + (years * _WEEKS_PER_YEAR) + weeks=weeks + (years * _WEEKS_PER_YEAR), ) return delta @@ -273,7 +275,7 @@ def strip_margins(s: str, marginChar: str) -> str: for line in lines: if marginChar in line: - revised_line: str = line[line.index(marginChar) + 1:] + revised_line: str = line[line.index(marginChar) + 1 :] revised_lines.append(revised_line) else: revised_lines.append(line) @@ -324,8 +326,8 @@ def match_condition(matchList: list[Any], matchFn: Callable[[Any], bool]) -> int ix: int = match_condition(lst, cond) if ix != -1: retval.extend(split_list_matching_condition(lst[0:ix], cond)) - retval.append(lst[ix:ix + 1]) - retval.extend(split_list_matching_condition(lst[ix + 1:], cond)) + retval.append(lst[ix : ix + 1]) + retval.extend(split_list_matching_condition(lst[ix + 1 :], cond)) else: retval = [lst] @@ -334,7 +336,7 @@ def match_condition(matchList: list[Any], matchFn: Callable[[Any], bool]) -> int def json_value_from_path(searchPath: str, jsonData: str, defaultValue: object) -> object: - """ Get JSON value from JSON data referenced by searchPath + """Get JSON value from JSON data referenced by searchPath searchPath should be a JSON path as supported by the `jmespath` package (see https://jmespath.org/) @@ -358,7 +360,7 @@ def json_value_from_path(searchPath: str, jsonData: str, defaultValue: object) - def system_time_millis() -> int: - """ return system time as milliseconds since start of epoch + """return system time as milliseconds since start of epoch :return: system time millis as long """ @@ -401,3 +403,17 @@ def write_data_to_output(df: DataFrame, output_dataset: OutputDataset) -> Stream ) return None + + +def ensure_column(column: ColumnLike) -> Column: + """ + Normalizes an input ``ColumnLike`` value into a Column expression. + + :param column: Column name as a string or a ``pyspark.sql.Column`` expression + :return: ``pyspark.sql.Column`` expression + """ + if isinstance(column, Column): + return column + + if isinstance(column, str): + return col(column) diff --git a/docs/source/conf.py b/docs/source/conf.py index 95998dd7..6580a2d2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,13 +45,10 @@ # 'numpydoc', # handle NumPy documentation formatted docstrings. Needs to install 'recommonmark', # allow including Commonmark markdown in sources 'sphinx_rtd_theme', - 'sphinx_copybutton' + 'sphinx_copybutton', ] -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown' -} +source_suffix = {'.rst': 'restructuredtext', '.md': 'markdown'} pdf_documents = [ ("index", project, project, author), @@ -107,9 +104,7 @@ # so a file named "default.css" will overwrite the builtin "default.css". 
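A brief usage sketch for the new ensure_column helper added to dbldatagen/utils.py above: it passes Column expressions through unchanged and wraps plain string names with col. The column names are placeholders, and an active Spark session is created first because classic PySpark needs one to build Column expressions.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from dbldatagen.utils import ensure_column

spark = SparkSession.builder.master("local[1]").appName("ensure_column_demo").getOrCreate()

c1 = ensure_column("customer_id")        # string name is wrapped via col()
c2 = ensure_column(F.col("amount") * 2)  # Column expressions are returned as-is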
html_static_path = ['_static'] -html_css_files = [ - 'css/tdg.css' -] +html_css_files = ['css/tdg.css'] # html_sidebars={ # '**' : [ 'globaltoc.html'] diff --git a/docs/utils/mk_quick_index.py b/docs/utils/mk_quick_index.py index db89ab9b..80899e96 100644 --- a/docs/utils/mk_quick_index.py +++ b/docs/utils/mk_quick_index.py @@ -17,90 +17,70 @@ """ SOURCE_FILES = { - "data_generator.py": {"briefDesc": "Main generator classes", - "grouping": "main classes"}, - "column_generation_spec.py": {"briefDesc": "Column Generation Spec types", - "grouping": "internal classes"}, - "column_spec_options.py": {"briefDesc": "Column Generation Options", - "grouping": "main classes"}, - "datarange.py": {"briefDesc": "Internal data range abstract types", - "grouping": "main classes"}, - "daterange.py": {"briefDesc": "Date and time ranges", - "grouping": "main classes"}, - - "datasets_object.py": {"briefDesc": "Entry point for standard datasets", - "grouping": "main classes"}, - - "nrange.py": {"briefDesc": "Numeric ranges", - "grouping": "main classes"}, - "text_generators.py": {"briefDesc": "Text data generation", - "grouping": "main classes"}, - "text_generator_plugins.py": {"briefDesc": "Text data generation", - "grouping": "main classes"}, - "data_analyzer.py": {"briefDesc": "Analysis of existing data", - "grouping": "main classes"}, - "function_builder.py": {"briefDesc": "Internal utilities to create functions related to weights", - "grouping": "internal classes"}, - "schema_parser.py": {"briefDesc": "Internal utilities to parse Spark SQL schema information", - "grouping": "internal classes"}, - "spark_singleton.py": {"briefDesc": "Spark singleton for test purposes", - "grouping": "internal classes"}, - "utils.py": {"briefDesc": "", - "grouping": "internal classes"}, - "html_utils.py": {"briefDesc": "", - "grouping": "internal classes"}, - - "beta.py": {"briefDesc": "Beta distribution related code", - "grouping": "data distribution"}, - "data_distribution.py": {"briefDesc": "Data distribution related code", - "grouping": "data distribution"}, - "normal_distribution.py": {"briefDesc": "Normal data distribution related code", - "grouping": "data distribution"}, - "gamma.py": {"briefDesc": "Gamma data distribution related code", - "grouping": "data distribution"}, - "exponential_distribution.py": {"briefDesc": "Exponential data distribution related code", - "grouping": "data distribution"}, - - "basic_user.py": {"briefDesc": "Provider for `basic/user` standard dataset", - "grouping": "Standard datasets"}, - "dataset_provider.py": {"briefDesc": "Base class for standard dataset providers", - "grouping": "Standard datasets"}, - "multi_table_telephony_provider.py": {"briefDesc": "Provider for `multi-table/telephony` standard dataset", - "grouping": "Standard datasets"}, - "constraint.py": {"briefDesc": "Constraint related code", - "grouping": "data generation constraints"}, - "chained_relation.py": {"briefDesc": "ChainedInequality constraint related code", - "grouping": "data generation constraints"}, - "value_multiple_constraint.py": {"briefDesc": "FixedIncrement constraint related code", - "grouping": "data generation constraints"}, - "negative_values.py": {"briefDesc": "Negative constraint related code", - "grouping": "data generation constraints"}, - "positive_values.py": {"briefDesc": "Positive constraint related code", - "grouping": "data generation constraints"}, - "literal_relation_constraint.py": {"briefDesc": "Scalar inequality constraint related code", - "grouping": "data generation constraints"}, - 
"literal_range_constraint.py": {"briefDesc": "ScalarRange constraint related code", - "grouping": "data generation constraints"}, - "sql_expr.py": {"briefDesc": "SQL expression constraint related code", - "grouping": "data generation constraints"}, - + "data_generator.py": {"briefDesc": "Main generator classes", "grouping": "main classes"}, + "column_generation_spec.py": {"briefDesc": "Column Generation Spec types", "grouping": "internal classes"}, + "column_spec_options.py": {"briefDesc": "Column Generation Options", "grouping": "main classes"}, + "datarange.py": {"briefDesc": "Internal data range abstract types", "grouping": "main classes"}, + "daterange.py": {"briefDesc": "Date and time ranges", "grouping": "main classes"}, + "datasets_object.py": {"briefDesc": "Entry point for standard datasets", "grouping": "main classes"}, + "nrange.py": {"briefDesc": "Numeric ranges", "grouping": "main classes"}, + "text_generators.py": {"briefDesc": "Text data generation", "grouping": "main classes"}, + "text_generator_plugins.py": {"briefDesc": "Text data generation", "grouping": "main classes"}, + "data_analyzer.py": {"briefDesc": "Analysis of existing data", "grouping": "main classes"}, + "function_builder.py": { + "briefDesc": "Internal utilities to create functions related to weights", + "grouping": "internal classes", + }, + "schema_parser.py": { + "briefDesc": "Internal utilities to parse Spark SQL schema information", + "grouping": "internal classes", + }, + "spark_singleton.py": {"briefDesc": "Spark singleton for test purposes", "grouping": "internal classes"}, + "utils.py": {"briefDesc": "", "grouping": "internal classes"}, + "html_utils.py": {"briefDesc": "", "grouping": "internal classes"}, + "beta.py": {"briefDesc": "Beta distribution related code", "grouping": "data distribution"}, + "data_distribution.py": {"briefDesc": "Data distribution related code", "grouping": "data distribution"}, + "normal_distribution.py": {"briefDesc": "Normal data distribution related code", "grouping": "data distribution"}, + "gamma.py": {"briefDesc": "Gamma data distribution related code", "grouping": "data distribution"}, + "exponential_distribution.py": { + "briefDesc": "Exponential data distribution related code", + "grouping": "data distribution", + }, + "basic_user.py": {"briefDesc": "Provider for `basic/user` standard dataset", "grouping": "Standard datasets"}, + "dataset_provider.py": {"briefDesc": "Base class for standard dataset providers", "grouping": "Standard datasets"}, + "multi_table_telephony_provider.py": { + "briefDesc": "Provider for `multi-table/telephony` standard dataset", + "grouping": "Standard datasets", + }, + "constraint.py": {"briefDesc": "Constraint related code", "grouping": "data generation constraints"}, + "chained_relation.py": { + "briefDesc": "ChainedInequality constraint related code", + "grouping": "data generation constraints", + }, + "value_multiple_constraint.py": { + "briefDesc": "FixedIncrement constraint related code", + "grouping": "data generation constraints", + }, + "negative_values.py": {"briefDesc": "Negative constraint related code", "grouping": "data generation constraints"}, + "positive_values.py": {"briefDesc": "Positive constraint related code", "grouping": "data generation constraints"}, + "literal_relation_constraint.py": { + "briefDesc": "Scalar inequality constraint related code", + "grouping": "data generation constraints", + }, + "literal_range_constraint.py": { + "briefDesc": "ScalarRange constraint related code", + "grouping": "data generation 
constraints", + }, + "sql_expr.py": {"briefDesc": "SQL expression constraint related code", "grouping": "data generation constraints"}, } # grouping metadata information # note that the GROUPING_INFO will be output in the order that they appear here GROUPING_INFO = { - "main classes": { - "heading": "Main user facing classes, functions and types" - }, - "internal classes": { - "heading": "Internal classes, functions and types" - }, - "data distribution": { - "heading": "Data distribution related classes, functions and types" - }, - "data generation constraints": { - "heading": "Data generation constraints related classes, functions and types" - } + "main classes": {"heading": "Main user facing classes, functions and types"}, + "internal classes": {"heading": "Internal classes, functions and types"}, + "data distribution": {"heading": "Data distribution related classes, functions and types"}, + "data generation constraints": {"heading": "Data generation constraints related classes, functions and types"}, } PACKAGE_NAME = "dbldatagen" @@ -108,7 +88,7 @@ def writeUnderlined(outputFile, text, underline="="): - """ write underlined text in RST markup format + """write underlined text in RST markup format :param outputFile: output file to write to :param text: text to write @@ -131,7 +111,7 @@ class FileMeta: BRIEF_DESCRIPTION = "briefDesc" def __init__(self, moduleName: str, metadata, classes, functions, types, subpackage: str = None): - """ Constructor for File Meta + """Constructor for File Meta :param moduleName: :param metadata: @@ -153,14 +133,16 @@ def __init__(self, moduleName: str, metadata, classes, functions, types, subpack @property def isPopulated(self): - """ Check if instance has any classes, functions or types""" - return ((self.classes is not None and len(self.classes) > 0) - or (self.functions is not None and len(self.functions) > 0) - or (self.types is not None and len(self.types) > 0)) + """Check if instance has any classes, functions or types""" + return ( + (self.classes is not None and len(self.classes) > 0) + or (self.functions is not None and len(self.functions) > 0) + or (self.types is not None and len(self.types) > 0) + ) def findMembers(sourceFile, fileMetadata, fileSubpackage): - """ Find classes, types and functions in file + """Find classes, types and functions in file :param fileMetadata: metadata for file :param fileSubpackage: subpackage for file @@ -194,12 +176,14 @@ def findMembers(sourceFile, fileMetadata, fileSubpackage): except Exception as e: print(f"*** failed to process file: {fname}") - return FileMeta(moduleName=Path(sourceFile.name).stem, - metadata=fileMetadata, - subpackage=fileSubpackage, - classes=sorted(classes), - functions=sorted(functions), - types=sorted(types)) + return FileMeta( + moduleName=Path(sourceFile.name).stem, + metadata=fileMetadata, + subpackage=fileSubpackage, + classes=sorted(classes), + functions=sorted(functions), + types=sorted(types), + ) def includeTemplate(outputFile): @@ -215,7 +199,7 @@ def includeTemplate(outputFile): def processItemList(outputFile, items, sectionTitle, subpackage=None): - """ process list of items + """process list of items :param outputFile: output file instance :param items: list of items. 
each item is a tuple of ( "moduleName.typename", "type description") @@ -240,7 +224,7 @@ def processItemList(outputFile, items, sectionTitle, subpackage=None): def processDirectory(outputFile, pathToProcess, subpackage=None): - """ process directory for package or subpackage + """process directory for package or subpackage :param outputFile: output file instance :param pathToProcess: path to process @@ -261,9 +245,7 @@ def processDirectory(outputFile, pathToProcess, subpackage=None): title = SOURCE_FILES[relativeFile.name]["briefDesc"] # get the classes, functions and types for the file - fileMetaInfo = findMembers(fp, - fileMetadata=SOURCE_FILES[relativeFile.name], - fileSubpackage=subpackage) + fileMetaInfo = findMembers(fp, fileMetadata=SOURCE_FILES[relativeFile.name], fileSubpackage=subpackage) assert fileMetaInfo is not None @@ -274,7 +256,9 @@ def processDirectory(outputFile, pathToProcess, subpackage=None): assert type(fileGroupings[fileMetaInfo.grouping]) is list fileGroupings[fileMetaInfo.grouping].append(fileMetaInfo) else: - newEntry = [fileMetaInfo, ] + newEntry = [ + fileMetaInfo, + ] assert type(newEntry) is list fileGroupings[fileMetaInfo.grouping] = newEntry elif not relativeFile.name.startswith("_"): @@ -288,30 +272,33 @@ def processDirectory(outputFile, pathToProcess, subpackage=None): fileMetaInfoList = fileGroupings[grp] # get the list of classes for the package - classList = [(f"{fileMetaInfo.moduleName}.{cls}", fileMetaInfo.description) - for fileMetaInfo in fileMetaInfoList for cls in fileMetaInfo.classes - if fileMetaInfo.isPopulated] + classList = [ + (f"{fileMetaInfo.moduleName}.{cls}", fileMetaInfo.description) + for fileMetaInfo in fileMetaInfoList + for cls in fileMetaInfo.classes + if fileMetaInfo.isPopulated + ] # get the list of functions for the package - functionList = [(f"{fileMetaInfo.moduleName}.{fn}", fileMetaInfo.description) - for fileMetaInfo in fileMetaInfoList for fn in fileMetaInfo.functions - if fileMetaInfo.isPopulated] + functionList = [ + (f"{fileMetaInfo.moduleName}.{fn}", fileMetaInfo.description) + for fileMetaInfo in fileMetaInfoList + for fn in fileMetaInfo.functions + if fileMetaInfo.isPopulated + ] # get the list of types for the package - typeList = [(f"{fileMetaInfo.moduleName}.{typ}", "") - for fileMetaInfo in fileMetaInfoList for typ in fileMetaInfo.types - if fileMetaInfo.isPopulated] + typeList = [ + (f"{fileMetaInfo.moduleName}.{typ}", "") + for fileMetaInfo in fileMetaInfoList + for typ in fileMetaInfo.types + if fileMetaInfo.isPopulated + ] # emit each of the sections in the index - processItemList(outputFile, classList, - sectionTitle="Classes", - subpackage=subpackage) - processItemList(outputFile, functionList, - sectionTitle="Functions", - subpackage=subpackage) - processItemList(outputFile, typeList, - sectionTitle="Types", - subpackage=subpackage) + processItemList(outputFile, classList, sectionTitle="Classes", subpackage=subpackage) + processItemList(outputFile, functionList, sectionTitle="Functions", subpackage=subpackage) + processItemList(outputFile, typeList, sectionTitle="Types", subpackage=subpackage) def main(dirToSearch, outputPath): diff --git a/docs/utils/mk_requirements.py b/docs/utils/mk_requirements.py index a20588b1..d45f763e 100644 --- a/docs/utils/mk_requirements.py +++ b/docs/utils/mk_requirements.py @@ -8,10 +8,14 @@ with open(sys.argv[2], "w") as writer: writer.write("# Dependencies for the Data Generator framework\n\n") - writer.write("Building the framework will install a number of packages in the 
virtual environment used for building the framework.\n") - writer.write("Some of these are required at runtime, while a number of packages are used for building the documentation only.\n\n") + writer.write( + "Building the framework will install a number of packages in the virtual environment used for building the framework.\n" + ) + writer.write( + "Some of these are required at runtime, while a number of packages are used for building the documentation only.\n\n" + ) for line in lines: - line =line.strip() + line = line.strip() if line.startswith("#"): line = line[1:] elif not len(line) == 0: diff --git a/pyproject.toml b/pyproject.toml index aab70aa9..2f6ec796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,11 @@ description = "Databricks Labs - PySpark Synthetic Data Generator" authors = [ {name = "Databricks Labs", email = "labs-oss@databricks.com"}, {name = "Ronan Stokes", email = "ronan.stokes@databricks.com"}, +] +maintainers = [ + {name = "Daniel Tomes", email = "daniel.tomes@databricks.com"}, {name = "Greg Hansen", email = "gregory.hansen@databricks.com"}, + {name = "Anup Kalburgi", email = "anup.kalburgi@databricks.com"} ] readme = "README.md" license = {text = "Databricks License"} @@ -79,6 +83,7 @@ include = [ [tool.hatch.envs.default] dependencies = [ + "black~=25.12.0", "chispa~=0.10.1", "coverage[toml]~=7.4.4", "mypy~=1.9.0", @@ -112,13 +117,21 @@ path = ".venv" [tool.hatch.envs.default.scripts] test = "pytest tests/ -n 10 --cov --cov-report=html --timeout 600 --durations 20" -fmt = ["ruff check . --fix", +fmt = ["black .", + "ruff check . --fix", "mypy .", "pylint --output-format=colorized -j 0 dbldatagen tests"] -verify = ["ruff check .", +verify = ["black --check .", + "ruff check .", "mypy .", "pylint --output-format=colorized -j 0 dbldatagen tests"] +[tool.black] +target-version = ["py310"] +line-length = 120 +skip-string-normalization = true +extend-exclude = "examples" + # Ruff configuration - replaces flake8, isort, pydocstyle, etc. 
[tool.ruff] target-version = "py310" @@ -137,15 +150,12 @@ exclude = [ "tutorial", "examples", "tests", - "dbldatagen/constraints", - "dbldatagen/distributions", "dbldatagen/__init__.py", + "dbldatagen/constraints/__init__.py", + "dbldatagen/distributions/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", "dbldatagen/datagen_constants.py", - "dbldatagen/datarange.py", - "dbldatagen/daterange.py", - "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", ] @@ -215,16 +225,14 @@ ignore = [ "tutorial", "examples", "tests", - "dbldatagen/constraints", "dbldatagen/datasets", - "dbldatagen/distributions", "dbldatagen/__init__.py", + "dbldatagen/constraints/__init__.py", + "dbldatagen/distributions/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", "dbldatagen/datagen_constants.py", - "dbldatagen/datarange.py", "dbldatagen/daterange.py", - "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", "dbldatagen/utils.py" @@ -248,7 +256,6 @@ ignore-paths = [ "tutorial", "examples", "tests", - "dbldatagen/constraints", "dbldatagen/datasets", "dbldatagen/distributions", "dbldatagen/__init__.py", @@ -256,9 +263,6 @@ ignore-paths = [ "dbldatagen/column_spec_options.py", "dbldatagen/data_generator.py", "dbldatagen/datagen_constants.py", - "dbldatagen/datarange.py", - "dbldatagen/daterange.py", - "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", "dbldatagen/utils.py" @@ -377,16 +381,13 @@ exclude = [ "tests", "tutorial", "tests", - "dbldatagen/constraints", "dbldatagen/datasets", - "dbldatagen/distributions", "dbldatagen/__init__.py", + "dbldatagen/constraints/__init__.py", + "dbldatagen/distributions/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", "dbldatagen/datagen_constants.py", - "dbldatagen/datarange.py", - "dbldatagen/daterange.py", - "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", "dbldatagen/utils.py" diff --git a/tests/__init__.py b/tests/__init__.py index ddbf42dc..7cb94c3c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -20,17 +20,23 @@ This module defines the test data generator and supporting classes """ -__all__ = ["test_basic_test", "test_quick_tests", "test_columnGenerationSpec", - "test_pandas_integration", - "test_distributions", - "test_ranged_values_and_dates", - "test_ranges", - "test_options", - "test_topological_sort", - "test_large_schema", "test_schema_parser", "test_scripting", - "test_text_generation", "test_iltext_generation", - "test_types", - "test_utils", - "test_weights", - "test_dependent_data" - ] +__all__ = [ + "test_basic_test", + "test_quick_tests", + "test_columnGenerationSpec", + "test_pandas_integration", + "test_distributions", + "test_ranged_values_and_dates", + "test_ranges", + "test_options", + "test_topological_sort", + "test_large_schema", + "test_schema_parser", + "test_scripting", + "test_text_generation", + "test_iltext_generation", + "test_types", + "test_utils", + "test_weights", + "test_dependent_data", +] diff --git a/tests/test_basic_test.py b/tests/test_basic_test.py index 9b9ddf31..99901314 100644 --- a/tests/test_basic_test.py +++ b/tests/test_basic_test.py @@ -24,17 +24,18 @@ class TestBasicOperation: @pytest.fixture(scope="class") def testDataSpec(self, setupLogging): - retval = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.SMALL_ROW_COUNT, - seedMethod='hash_fieldname') - .withIdOutput() - 
.withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=self.column_count) - .withColumn("code1", IntegerType(), min=100, max=200) - .withColumn("code2", IntegerType(), min=0, max=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + retval = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set1", rows=self.SMALL_ROW_COUNT, seedMethod='hash_fieldname' + ) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=self.column_count) + .withColumn("code1", IntegerType(), min=100, max=200) + .withColumn("code2", IntegerType(), min=0, max=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) return retval @pytest.fixture(scope="class") @@ -42,7 +43,7 @@ def testData(self, testDataSpec): return testDataSpec.build().cache() def setup_log_capture(self, caplog_object): - """ set up log capture fixture + """set up log capture fixture Sets up log capture fixture to only capture messages after setup and only capture warnings and errors @@ -82,13 +83,14 @@ def test_default_partition_assignment(self, testDataSpec): def test_basic_data_generation(self, testData): """Test basic data generation of distinct values""" - counts = testData.agg(F.countDistinct("id").alias("id_count"), - F.countDistinct("code1").alias("code1_count"), - F.countDistinct("code2").alias("code2_count"), - F.countDistinct("code3").alias("code3_count"), - F.countDistinct("code4").alias("code4_count"), - F.countDistinct("code5").alias("code5_count") - ).collect()[0] + counts = testData.agg( + F.countDistinct("id").alias("id_count"), + F.countDistinct("code1").alias("code1_count"), + F.countDistinct("code2").alias("code2_count"), + F.countDistinct("code3").alias("code3_count"), + F.countDistinct("code4").alias("code4_count"), + F.countDistinct("code5").alias("code5_count"), + ).collect()[0] assert counts["id_count"] == self.row_count assert counts["code1_count"] == 101 @@ -101,18 +103,24 @@ def test_alt_seed_column(self, caplog): # caplog fixture captures log content self.setup_log_capture(caplog) - dgspec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=10000, - partitions=4, seedMethod='hash_fieldname', verbose=True, - seedColumnName="_id") - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=4) - .withColumn("code1", IntegerType(), min=100, max=200) - .withColumn("code2", IntegerType(), min=0, max=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + dgspec = ( + dg.DataGenerator( + sparkSession=spark, + name="alt_data_set", + rows=10000, + partitions=4, + seedMethod='hash_fieldname', + verbose=True, + seedColumnName="_id", + ) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=4) + .withColumn("code1", IntegerType(), min=100, max=200) + .withColumn("code2", IntegerType(), min=0, max=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", 
StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) fieldsFromGenerator = set(dgspec.getOutputColumnNames()) @@ -136,18 +144,19 @@ def test_alt_seed_column(self, caplog): seed_column_warnings_and_errors = self.get_log_capture_warngings_and_errors(caplog, "seed") assert seed_column_warnings_and_errors == 0, "Should not have error messages about seed column" - @pytest.mark.parametrize("caseName, withIdOutput, idType, additionalOptions", - [("withIdOutput", True, FloatType(), {}), - ("withIdOutput multicolumn", True, FloatType(), {'numColumns': 4}), - ("with no Id output", False, FloatType(), {}), - ("with no Id output multicolumn", False, FloatType(), {'numColumns': 4}), - ("with no Id output random", - False, - IntegerType(), - {'uniqueValues': 5000, 'random': True}) - ]) - def test_seed_column_nocollision(self, caseName, withIdOutput, idType, additionalOptions, caplog): \ - # pylint: disable=too-many-positional-arguments + @pytest.mark.parametrize( + "caseName, withIdOutput, idType, additionalOptions", + [ + ("withIdOutput", True, FloatType(), {}), + ("withIdOutput multicolumn", True, FloatType(), {'numColumns': 4}), + ("with no Id output", False, FloatType(), {}), + ("with no Id output multicolumn", False, FloatType(), {'numColumns': 4}), + ("with no Id output random", False, IntegerType(), {'uniqueValues': 5000, 'random': True}), + ], + ) + def test_seed_column_nocollision( + self, caseName, withIdOutput, idType, additionalOptions, caplog + ): # pylint: disable=too-many-positional-arguments logging.info(f"case: {caseName}") @@ -155,20 +164,25 @@ def test_seed_column_nocollision(self, caseName, withIdOutput, idType, additiona self.setup_log_capture(caplog) # test that there are no collisions on the use of the 'id' field) - dgSpec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=10000, - partitions=4, seedMethod='hash_fieldname', verbose=True, - seedColumnName="_id")) + dgSpec = dg.DataGenerator( + sparkSession=spark, + name="alt_data_set", + rows=10000, + partitions=4, + seedMethod='hash_fieldname', + verbose=True, + seedColumnName="_id", + ) if withIdOutput: dgSpec = dgSpec.withIdOutput() - dgSpec = (dgSpec - .withColumn("id", idType, expr="floor(rand() * 350) * (86400 + 3600)", - **additionalOptions) - .withColumn("code1", IntegerType(), min=100, max=200) - .withColumn("code2", IntegerType(), min=0, max=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - ) + dgSpec = ( + dgSpec.withColumn("id", idType, expr="floor(rand() * 350) * (86400 + 3600)", **additionalOptions) + .withColumn("code1", IntegerType(), min=100, max=200) + .withColumn("code2", IntegerType(), min=0, max=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + ) fieldsFromGenerator = set(dgSpec.getOutputColumnNames()) @@ -181,18 +195,22 @@ def test_seed_column_nocollision(self, caseName, withIdOutput, idType, additiona seed_column_warnings_and_errors = self.get_log_capture_warngings_and_errors(caplog, "seed") assert seed_column_warnings_and_errors == 0, "Should not have error messages about seed column" - @pytest.mark.parametrize("caseName, withIdOutput, idType, idName", - [("withIdOutput float", True, FloatType(), "id"), - ("withIdOutput int", True, IntegerType(), "id"), - ("with no Id output float", False, FloatType(), "id"), - ("with no Id output int", False, IntegerType(), "id"), - ("withIdOutput _id float", True, FloatType(), "_id"), - ("withIdOutput _id int", True, IntegerType(), 
"_id"), - ("with no Id output _id float", False, FloatType(), "_id"), - ("with no Id output _id int", False, IntegerType(), "_id"), - ]) - def test_seed_column_expected_collision1(self, caseName, withIdOutput, idType, idName, caplog): \ - # pylint: disable=too-many-positional-arguments + @pytest.mark.parametrize( + "caseName, withIdOutput, idType, idName", + [ + ("withIdOutput float", True, FloatType(), "id"), + ("withIdOutput int", True, IntegerType(), "id"), + ("with no Id output float", False, FloatType(), "id"), + ("with no Id output int", False, IntegerType(), "id"), + ("withIdOutput _id float", True, FloatType(), "_id"), + ("withIdOutput _id int", True, IntegerType(), "_id"), + ("with no Id output _id float", False, FloatType(), "_id"), + ("with no Id output _id int", False, IntegerType(), "_id"), + ], + ) + def test_seed_column_expected_collision1( + self, caseName, withIdOutput, idType, idName, caplog + ): # pylint: disable=too-many-positional-arguments logging.info(f"case: {caseName}") @@ -200,26 +218,34 @@ def test_seed_column_expected_collision1(self, caseName, withIdOutput, idType, i self.setup_log_capture(caplog) if idName == "id": - dgSpec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=10000, - partitions=4, seedMethod='hash_fieldname', verbose=True) - ) + dgSpec = dg.DataGenerator( + sparkSession=spark, + name="alt_data_set", + rows=10000, + partitions=4, + seedMethod='hash_fieldname', + verbose=True, + ) else: - dgSpec = (dg.DataGenerator(sparkSession=spark, name="alt_data_set", rows=10000, - partitions=4, seedMethod='hash_fieldname', verbose=True, - seedColumnName=idName - ) - ) + dgSpec = dg.DataGenerator( + sparkSession=spark, + name="alt_data_set", + rows=10000, + partitions=4, + seedMethod='hash_fieldname', + verbose=True, + seedColumnName=idName, + ) if withIdOutput: dgSpec = dgSpec.withIdOutput() - dgSpec = (dgSpec - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=4) - .withColumn(idName, idType, min=100, max=200) - .withColumn("code2", IntegerType(), min=0, max=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - ) + dgSpec = ( + dgSpec.withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=4) + .withColumn(idName, idType, min=100, max=200) + .withColumn("code2", IntegerType(), min=0, max=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + ) fieldsFromGenerator = set(dgSpec.getOutputColumnNames()) @@ -244,9 +270,11 @@ def test_clone(self, testDataSpec): """Test clone method""" ds_copy1 = testDataSpec.clone() - df_copy1 = (ds_copy1.withRowCount(1000) - .withColumn("another_column", StringType(), values=['a', 'b', 'c'], random=True) - .build()) + df_copy1 = ( + ds_copy1.withRowCount(1000) + .withColumn("another_column", StringType(), values=['a', 'b', 'c'], random=True) + .build() + ) assert df_copy1.count() == 1000 fields1 = ds_copy1.getOutputColumnNames() @@ -263,11 +291,13 @@ def test_multiple_base_columns(self, testDataSpec): """Test data generation with multiple base columns""" ds_copy1 = testDataSpec.clone() - df_copy1 = (ds_copy1.withRowCount(self.TINY_ROW_COUNT) - .withColumn("ac1", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200) - .withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], - minValue=100, maxValue=200, random=True) - .build().cache()) + df_copy1 = ( + ds_copy1.withRowCount(self.TINY_ROW_COUNT) + .withColumn("ac1", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200) + 
.withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200, random=True) + .build() + .cache() + ) assert df_copy1.count() == 1000 df_overlimit = df_copy1.where("ac1 > 200") @@ -288,48 +318,51 @@ def test_repeatable_multiple_base_columns(self, testDataSpec): """ ds_copy1 = testDataSpec.clone() - df_copy1 = (ds_copy1.withRowCount(1000) - .withColumn("ac1", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200) - .withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], - minValue=100, maxValue=200, random=True) - .build()) + df_copy1 = ( + ds_copy1.withRowCount(1000) + .withColumn("ac1", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200) + .withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200, random=True) + .build() + ) assert df_copy1.count() == 1000 df_copy1.createOrReplaceTempView("test_data") # check that for each combination of code1 and code2, we only have a single value of ac1 - df_check = spark.sql("""select * from (select count(ac1) as count_ac1 + df_check = spark.sql( + """select * from (select count(ac1) as count_ac1 from (select distinct ac1, code1, code2 from test_data) group by code1, code2) where count_ac1 < 1 or count_ac1 > 1 - """) + """ + ) assert df_check.count() == 0 def test_default_spark_instance(self): - """ Test different types of seeding for random values""" - ds1 = (dg.DataGenerator(name="test_data_set1", rows=1000, seedMethod='hash_fieldname') - .withIdOutput() - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + """Test different types of seeding for random values""" + ds1 = ( + dg.DataGenerator(name="test_data_set1", rows=1000, seedMethod='hash_fieldname') + .withIdOutput() + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = ds1.build() assert df.count() == 1000 def test_default_spark_instance2(self): - """ Test different types of seeding for random values""" - ds1 = (dg.DataGenerator(name="test_data_set1", rows=1000, seedMethod='hash_fieldname') - .withIdOutput() - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + """Test different types of seeding for random values""" + ds1 = ( + dg.DataGenerator(name="test_data_set1", rows=1000, seedMethod='hash_fieldname') + .withIdOutput() + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) ds1._setupSparkSession(None) @@ -337,15 +370,15 @@ def test_default_spark_instance2(self): assert sparkSession is not None def test_multiple_hash_methods(self): - """ Test different types of seeding for random values""" - ds1 = 
(dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='hash_fieldname') - .withIdOutput() - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + """Test different types of seeding for random values""" + ds1 = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='hash_fieldname') + .withIdOutput() + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = ds1.build() assert df.count() == 1000 @@ -357,14 +390,14 @@ def test_multiple_hash_methods(self): assert df_count_values2.count() == 0 df_count_values3 = df.where("code5 not in ('a', 'b', 'c')") assert df_count_values3.count() == 0 - ds2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='fixed') - .withIdOutput() - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + ds2 = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='fixed') + .withIdOutput() + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df2 = ds2.build() assert df2.count() == 1000 @@ -376,14 +409,14 @@ def test_multiple_hash_methods(self): assert df2_count_values2.count() == 0 df2_count_values3 = df2.where("code5 not in ('a', 'b', 'c')") assert df2_count_values3.count() == 0 - ds3 = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod=None) - .withIdOutput() - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + ds3 = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod=None) + .withIdOutput() + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df3 = ds3.build() assert df3.count() == 1000 @@ -399,12 +432,12 @@ def test_multiple_hash_methods(self): assert df3_count_values3.count() == 1000 def test_generated_data_count(self, testData): - """ Test that rows are generated for the number of rows indicated by the row count""" + """Test that rows are generated for the number of rows indicated by the row count""" count = testData.count() assert count == self.row_count def test_distinct_count(self, testData): - """ Test that ids are unique""" + """Test that ids 
are unique""" distinct_count = testData.select('id').distinct().count() assert distinct_count == self.row_count @@ -415,14 +448,22 @@ def test_column_count(self, testData): def test_values_code1(self, testData): """Test values""" - values = testData.select('code1').groupBy().agg(F.min('code1').alias('minValue'), - F.max('code1').alias('maxValue')).collect()[0] + values = ( + testData.select('code1') + .groupBy() + .agg(F.min('code1').alias('minValue'), F.max('code1').alias('maxValue')) + .collect()[0] + ) assert {100, 200} == {values.minValue, values.maxValue} def test_values_code2(self, testData): """Test values""" - values = testData.select('code2').groupBy().agg(F.min('code2').alias('minValue'), - F.max('code2').alias('maxValue')).collect()[0] + values = ( + testData.select('code2') + .groupBy() + .agg(F.min('code2').alias('minValue'), F.max('code2').alias('maxValue')) + .collect()[0] + ) assert {0, 10} == {values.minValue, values.maxValue} def test_values_code3(self, testData): @@ -460,16 +501,17 @@ def test_basic_adhoc(self, testDataSpec, testData): def test_basic_with_schema(self, testDataSpec): """Test use of schema""" - schema = StructType([ - StructField("region_id", IntegerType(), True), - StructField("region_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True), - StructField("state1", StringType(), True), - StructField("state2", StringType(), True), - StructField("st_desc", StringType(), True), - - ]) + schema = StructType( + [ + StructField("region_id", IntegerType(), True), + StructField("region_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("state1", StringType(), True), + StructField("state2", StringType(), True), + StructField("st_desc", StringType(), True), + ] + ) testDataSpec2 = testDataSpec.clone() print("data generation description:", testDataSpec2.describe()) @@ -477,9 +519,7 @@ def test_basic_with_schema(self, testDataSpec): print("data generation str:", str(testDataSpec2)) testDataSpec2.explain() - testDataSpec3 = (testDataSpec2.withSchema(schema) - .withColumnSpec("state1", values=['ca', 'wa', 'ny']) - ) + testDataSpec3 = testDataSpec2.withSchema(schema).withColumnSpec("state1", values=['ca', 'wa', 'ny']) print("output columns", testDataSpec3.getOutputColumnNames()) @@ -502,7 +542,8 @@ def test_partitions(self): .withColumn("code2", IntegerType(), maxValue=1000, step=5) .withColumn("code3", IntegerType(), minValue=100, maxValue=200, step=1, random=True) .withColumn("xcode", StringType(), values=["a", "test", "value"], random=True) - .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True)) + .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True) + ) df = testdata_defn.build() df.printSchema() @@ -515,9 +556,8 @@ def test_partitions(self): def test_percent_nulls(self): rows_wanted = 20000 - testdata_defn = ( - dg.DataGenerator(name="basic_dataset", rows=rows_wanted) - .withColumn("code1", IntegerType(), minValue=1, maxValue=20, step=1, percent_nulls=0.1) + testdata_defn = dg.DataGenerator(name="basic_dataset", rows=rows_wanted).withColumn( + "code1", IntegerType(), minValue=1, maxValue=20, step=1, percent_nulls=0.1 ) df = testdata_defn.build() diff --git a/tests/test_build_planning.py b/tests/test_build_planning.py index d52e2b89..efb22aca 100644 --- a/tests/test_build_planning.py +++ b/tests/test_build_planning.py @@ -6,141 +6,143 @@ import dbldatagen as dg -schema = 
StructType([ - StructField("PK1", StringType(), True), - StructField("XYYZ_IDS", StringType(), True), - StructField("R_ID", IntegerType(), True), - StructField("CL_ID", StringType(), True), - StructField("INGEST_DATE", TimestampType(), True), - StructField("CMPY_ID", DecimalType(38, 0), True), - StructField("TXN_ID", DecimalType(38, 0), True), - StructField("SEQUENCE_NUMBER", DecimalType(38, 0), True), - StructField("DETAIL_ORDER", DecimalType(38, 0), True), - StructField("TX_T_ID", DecimalType(38, 0), True), - StructField("TXN_DATE", TimestampType(), True), - StructField("AN_ID", DecimalType(38, 0), True), - StructField("ANC_ID", DecimalType(38, 0), True), - StructField("ANV_ID", DecimalType(38, 0), True), - StructField("ANE_ID", DecimalType(38, 0), True), - StructField("AND_ID", DecimalType(38, 0), True), - StructField("APM_ID", DecimalType(38, 0), True), - StructField("ACL_ID", DecimalType(38, 0), True), - StructField("MEMO_TEXT", StringType(), True), - StructField("ITEM_ID", DecimalType(38, 0), True), - StructField("ITEM2_ID", DecimalType(38, 0), True), - StructField("V1_BASE", DecimalType(38, 9), True), - StructField("V1_YTD_AMT", DecimalType(38, 9), True), - StructField("V1_YTD_HOURS", DecimalType(38, 0), True), - StructField("ISTT", DecimalType(38, 9), True), - StructField("XXX_AMT", StringType(), True), - StructField("XXX_BASE", StringType(), True), - StructField("XXX_ISTT", StringType(), True), - StructField("HOURS", DecimalType(38, 0), True), - StructField("STATE", DecimalType(38, 0), True), - StructField("LSTATE", DecimalType(38, 0), True), - StructField("XXX_JURISDICTION_ID", DecimalType(38, 0), True), - StructField("XXY_JURISDICTION_ID", DecimalType(38, 0), True), - StructField("AS_OF_DATE", TimestampType(), True), - StructField("IS_PAYOUT", StringType(), True), - StructField("IS_PYRL_LIABILITY", StringType(), True), - StructField("IS_PYRL_SUMMARY", StringType(), True), - StructField("PYRL_LIABILITY_DATE", TimestampType(), True), - StructField("PYRL_LIAB_BEGIN_DATE", TimestampType(), True), - StructField("QTY", DecimalType(38, 9), True), - StructField("RATE", DecimalType(38, 9), True), - StructField("AMOUNT", DecimalType(38, 9), True), - StructField("SPERCENT", DecimalType(38, 9), True), - StructField("DOC_XREF", StringType(), True), - StructField("IS_A", StringType(), True), - StructField("IS_S", StringType(), True), - StructField("IS_CP", StringType(), True), - StructField("IS_VP", StringType(), True), - StructField("IS_B", StringType(), True), - StructField("IS_EX", StringType(), True), - StructField("IS_I", StringType(), True), - StructField("IS_CL", StringType(), True), - StructField("IS_DPD", StringType(), True), - StructField("IS_DPD2", StringType(), True), - StructField("DPD_ID", DecimalType(38, 0), True), - StructField("IS_NP", StringType(), True), - StructField("TAXABLE_TYPE", DecimalType(38, 0), True), - StructField("IS_ARP", StringType(), True), - StructField("IS_APP", StringType(), True), - StructField("BALANCE1", DecimalType(38, 9), True), - StructField("BALANCE2", DecimalType(38, 9), True), - StructField("IS_FLAG1", StringType(), True), - StructField("IS_FLAG2", StringType(), True), - StructField("STATEMENT_ID", DecimalType(38, 0), True), - StructField("INVOICE_ID", DecimalType(38, 0), True), - StructField("STATEMENT_DATE", TimestampType(), True), - StructField("INVOICE_DATE", TimestampType(), True), - StructField("DUE_DATE", TimestampType(), True), - StructField("EXAMPLE1_ID", DecimalType(38, 0), True), - StructField("EXAMPLE2_ID", DecimalType(38, 0), True), - 
StructField("IS_FLAG3", StringType(), True), - StructField("ANOTHER_ID", DecimalType(38, 0), True), - StructField("MARKUP", DecimalType(38, 9), True), - StructField("S_DATE", TimestampType(), True), - StructField("SD_TYPE", DecimalType(38, 0), True), - StructField("SOURCE_TXN_ID", DecimalType(38, 0), True), - StructField("SOURCE_TXN_SEQUENCE", DecimalType(38, 0), True), - StructField("PAID_DATE", TimestampType(), True), - StructField("OFX_TXN_ID", DecimalType(38, 0), True), - StructField("OFX_MATCH_FLAG", DecimalType(38, 0), True), - StructField("OLB_MATCH_MODE", DecimalType(38, 0), True), - StructField("OLB_MATCH_AMOUNT", DecimalType(38, 9), True), - StructField("OLB_RULE_ID", DecimalType(38, 0), True), - StructField("ETMMODE", DecimalType(38, 0), True), - StructField("DDA_ID", DecimalType(38, 0), True), - StructField("DDL_STATUS", DecimalType(38, 0), True), - StructField("ICFS", DecimalType(38, 0), True), - StructField("CREATE_DATE", TimestampType(), True), - StructField("CREATE_USER_ID", DecimalType(38, 0), True), - StructField("LAST_MODIFY_DATE", TimestampType(), True), - StructField("LAST_MODIFY_USER_ID", DecimalType(38, 0), True), - StructField("EDIT_SEQUENCE", DecimalType(38, 0), True), - StructField("ADDED_AUDIT_ID", DecimalType(38, 0), True), - StructField("AUDIT_ID", DecimalType(38, 0), True), - StructField("AUDIT_FLAG", StringType(), True), - StructField("EXCEPTION_FLAG", StringType(), True), - StructField("IS_PENALTY", StringType(), True), - StructField("IS_INTEREST", StringType(), True), - StructField("NET_AMOUNT", DecimalType(38, 9), True), - StructField("TAX_AMOUNT", DecimalType(38, 9), True), - StructField("TAX_CODE_ID", DecimalType(38, 0), True), - StructField("TAX_RATE_ID", DecimalType(38, 0), True), - StructField("CURRENCY_TYPE", DecimalType(38, 0), True), - StructField("EXCHANGE_RATE", DecimalType(38, 9), True), - StructField("HA", DecimalType(38, 9), True), - StructField("HO_AMT", DecimalType(38, 9), True), - StructField("IS_FGL", StringType(), True), - StructField("ST_TYPE", DecimalType(38, 0), True), - StructField("STO_BALANCE", DecimalType(38, 9), True), - StructField("TO_AMT", DecimalType(38, 9), True), - StructField("INC_AMOUNT", DecimalType(38, 9), True), - StructField("CA_TAX_AMT", DecimalType(38, 9), True), - StructField("HGS_CODE_ID", DecimalType(38, 0), True), - StructField("DISC_ID", DecimalType(38, 0), True), - StructField("DISC_AMT", DecimalType(38, 9), True), - StructField("TXN_DISCOUNT_AMOUNT", DecimalType(38, 9), True), - StructField("SUBTOTAL_AMOUNT", DecimalType(38, 9), True), - StructField("LINE_DETAIL_TYPE", DecimalType(38, 0), True), - StructField("W_RATE_ID", DecimalType(38, 0), True), - StructField("R_QTY", DecimalType(38, 9), True), - StructField("R_AMOUNT", DecimalType(38, 9), True), - StructField("AMT_2", DecimalType(38, 9), True), - StructField("AMT_3", DecimalType(38, 9), True), - StructField("FLAG_5", StringType(), True), - StructField("CUSTOM_FIELD_VALUES", StringType(), True), - StructField("PTT", DecimalType(38, 0), True), - StructField("IRT", DecimalType(38, 0), True), - StructField("CUSTOM_FIELD_VALS", StringType(), True), - StructField("RCC", StringType(), True), - StructField("LAST_MODIFIED_UTC", TimestampType(), True), - StructField("date", DateType(), True), - StructField("yearMonth", StringType(), True), - StructField("isDeleted", BooleanType(), True) -]) +schema = StructType( + [ + StructField("PK1", StringType(), True), + StructField("XYYZ_IDS", StringType(), True), + StructField("R_ID", IntegerType(), True), + 
StructField("CL_ID", StringType(), True), + StructField("INGEST_DATE", TimestampType(), True), + StructField("CMPY_ID", DecimalType(38, 0), True), + StructField("TXN_ID", DecimalType(38, 0), True), + StructField("SEQUENCE_NUMBER", DecimalType(38, 0), True), + StructField("DETAIL_ORDER", DecimalType(38, 0), True), + StructField("TX_T_ID", DecimalType(38, 0), True), + StructField("TXN_DATE", TimestampType(), True), + StructField("AN_ID", DecimalType(38, 0), True), + StructField("ANC_ID", DecimalType(38, 0), True), + StructField("ANV_ID", DecimalType(38, 0), True), + StructField("ANE_ID", DecimalType(38, 0), True), + StructField("AND_ID", DecimalType(38, 0), True), + StructField("APM_ID", DecimalType(38, 0), True), + StructField("ACL_ID", DecimalType(38, 0), True), + StructField("MEMO_TEXT", StringType(), True), + StructField("ITEM_ID", DecimalType(38, 0), True), + StructField("ITEM2_ID", DecimalType(38, 0), True), + StructField("V1_BASE", DecimalType(38, 9), True), + StructField("V1_YTD_AMT", DecimalType(38, 9), True), + StructField("V1_YTD_HOURS", DecimalType(38, 0), True), + StructField("ISTT", DecimalType(38, 9), True), + StructField("XXX_AMT", StringType(), True), + StructField("XXX_BASE", StringType(), True), + StructField("XXX_ISTT", StringType(), True), + StructField("HOURS", DecimalType(38, 0), True), + StructField("STATE", DecimalType(38, 0), True), + StructField("LSTATE", DecimalType(38, 0), True), + StructField("XXX_JURISDICTION_ID", DecimalType(38, 0), True), + StructField("XXY_JURISDICTION_ID", DecimalType(38, 0), True), + StructField("AS_OF_DATE", TimestampType(), True), + StructField("IS_PAYOUT", StringType(), True), + StructField("IS_PYRL_LIABILITY", StringType(), True), + StructField("IS_PYRL_SUMMARY", StringType(), True), + StructField("PYRL_LIABILITY_DATE", TimestampType(), True), + StructField("PYRL_LIAB_BEGIN_DATE", TimestampType(), True), + StructField("QTY", DecimalType(38, 9), True), + StructField("RATE", DecimalType(38, 9), True), + StructField("AMOUNT", DecimalType(38, 9), True), + StructField("SPERCENT", DecimalType(38, 9), True), + StructField("DOC_XREF", StringType(), True), + StructField("IS_A", StringType(), True), + StructField("IS_S", StringType(), True), + StructField("IS_CP", StringType(), True), + StructField("IS_VP", StringType(), True), + StructField("IS_B", StringType(), True), + StructField("IS_EX", StringType(), True), + StructField("IS_I", StringType(), True), + StructField("IS_CL", StringType(), True), + StructField("IS_DPD", StringType(), True), + StructField("IS_DPD2", StringType(), True), + StructField("DPD_ID", DecimalType(38, 0), True), + StructField("IS_NP", StringType(), True), + StructField("TAXABLE_TYPE", DecimalType(38, 0), True), + StructField("IS_ARP", StringType(), True), + StructField("IS_APP", StringType(), True), + StructField("BALANCE1", DecimalType(38, 9), True), + StructField("BALANCE2", DecimalType(38, 9), True), + StructField("IS_FLAG1", StringType(), True), + StructField("IS_FLAG2", StringType(), True), + StructField("STATEMENT_ID", DecimalType(38, 0), True), + StructField("INVOICE_ID", DecimalType(38, 0), True), + StructField("STATEMENT_DATE", TimestampType(), True), + StructField("INVOICE_DATE", TimestampType(), True), + StructField("DUE_DATE", TimestampType(), True), + StructField("EXAMPLE1_ID", DecimalType(38, 0), True), + StructField("EXAMPLE2_ID", DecimalType(38, 0), True), + StructField("IS_FLAG3", StringType(), True), + StructField("ANOTHER_ID", DecimalType(38, 0), True), + StructField("MARKUP", DecimalType(38, 9), 
True), + StructField("S_DATE", TimestampType(), True), + StructField("SD_TYPE", DecimalType(38, 0), True), + StructField("SOURCE_TXN_ID", DecimalType(38, 0), True), + StructField("SOURCE_TXN_SEQUENCE", DecimalType(38, 0), True), + StructField("PAID_DATE", TimestampType(), True), + StructField("OFX_TXN_ID", DecimalType(38, 0), True), + StructField("OFX_MATCH_FLAG", DecimalType(38, 0), True), + StructField("OLB_MATCH_MODE", DecimalType(38, 0), True), + StructField("OLB_MATCH_AMOUNT", DecimalType(38, 9), True), + StructField("OLB_RULE_ID", DecimalType(38, 0), True), + StructField("ETMMODE", DecimalType(38, 0), True), + StructField("DDA_ID", DecimalType(38, 0), True), + StructField("DDL_STATUS", DecimalType(38, 0), True), + StructField("ICFS", DecimalType(38, 0), True), + StructField("CREATE_DATE", TimestampType(), True), + StructField("CREATE_USER_ID", DecimalType(38, 0), True), + StructField("LAST_MODIFY_DATE", TimestampType(), True), + StructField("LAST_MODIFY_USER_ID", DecimalType(38, 0), True), + StructField("EDIT_SEQUENCE", DecimalType(38, 0), True), + StructField("ADDED_AUDIT_ID", DecimalType(38, 0), True), + StructField("AUDIT_ID", DecimalType(38, 0), True), + StructField("AUDIT_FLAG", StringType(), True), + StructField("EXCEPTION_FLAG", StringType(), True), + StructField("IS_PENALTY", StringType(), True), + StructField("IS_INTEREST", StringType(), True), + StructField("NET_AMOUNT", DecimalType(38, 9), True), + StructField("TAX_AMOUNT", DecimalType(38, 9), True), + StructField("TAX_CODE_ID", DecimalType(38, 0), True), + StructField("TAX_RATE_ID", DecimalType(38, 0), True), + StructField("CURRENCY_TYPE", DecimalType(38, 0), True), + StructField("EXCHANGE_RATE", DecimalType(38, 9), True), + StructField("HA", DecimalType(38, 9), True), + StructField("HO_AMT", DecimalType(38, 9), True), + StructField("IS_FGL", StringType(), True), + StructField("ST_TYPE", DecimalType(38, 0), True), + StructField("STO_BALANCE", DecimalType(38, 9), True), + StructField("TO_AMT", DecimalType(38, 9), True), + StructField("INC_AMOUNT", DecimalType(38, 9), True), + StructField("CA_TAX_AMT", DecimalType(38, 9), True), + StructField("HGS_CODE_ID", DecimalType(38, 0), True), + StructField("DISC_ID", DecimalType(38, 0), True), + StructField("DISC_AMT", DecimalType(38, 9), True), + StructField("TXN_DISCOUNT_AMOUNT", DecimalType(38, 9), True), + StructField("SUBTOTAL_AMOUNT", DecimalType(38, 9), True), + StructField("LINE_DETAIL_TYPE", DecimalType(38, 0), True), + StructField("W_RATE_ID", DecimalType(38, 0), True), + StructField("R_QTY", DecimalType(38, 9), True), + StructField("R_AMOUNT", DecimalType(38, 9), True), + StructField("AMT_2", DecimalType(38, 9), True), + StructField("AMT_3", DecimalType(38, 9), True), + StructField("FLAG_5", StringType(), True), + StructField("CUSTOM_FIELD_VALUES", StringType(), True), + StructField("PTT", DecimalType(38, 0), True), + StructField("IRT", DecimalType(38, 0), True), + StructField("CUSTOM_FIELD_VALS", StringType(), True), + StructField("RCC", StringType(), True), + StructField("LAST_MODIFIED_UTC", TimestampType(), True), + StructField("date", DateType(), True), + StructField("yearMonth", StringType(), True), + StructField("isDeleted", BooleanType(), True), + ] +) spark = dg.SparkSingleton.getLocalInstance("unit tests") @@ -162,24 +164,23 @@ def sampleDataSpec(self): sale_values = ['RETAIL', 'ONLINE', 'WHOLESALE', 'RETURN'] sale_weights = [1, 5, 5, 1] - testDataspec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, partitions=4) - 
.withSchema(schema) - .withIdOutput() - .withColumnSpecs(patterns=".*_ID", match_types=StringType(), format="%010d", - minValue=1, maxValue=123, - step=1) - .withColumnSpecs(patterns=".*_IDS", match_types="string", format="%010d", minValue=1, - maxValue=100, step=1) - # .withColumnSpec("R3D3_CLUSTER_IDS", minValue=1, maxValue=100, step=1) - .withColumnSpec("XYYZ_IDS", minValue=1, maxValue=123, step=1, - format="%05d") - # .withColumnSpec("nstr4", percentNulls=0.1, - # minValue=1, maxValue=9, step=2, format="%04d") - # example of IS_SALE - .withColumnSpec("IS_S", values=sale_values, weights=sale_weights, random=True) - # .withColumnSpec("nstr4", percentNulls=0.1, - # minValue=1, maxValue=9, step=2, format="%04d") - ) + testDataspec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, partitions=4) + .withSchema(schema) + .withIdOutput() + .withColumnSpecs( + patterns=".*_ID", match_types=StringType(), format="%010d", minValue=1, maxValue=123, step=1 + ) + .withColumnSpecs(patterns=".*_IDS", match_types="string", format="%010d", minValue=1, maxValue=100, step=1) + # .withColumnSpec("R3D3_CLUSTER_IDS", minValue=1, maxValue=100, step=1) + .withColumnSpec("XYYZ_IDS", minValue=1, maxValue=123, step=1, format="%05d") + # .withColumnSpec("nstr4", percentNulls=0.1, + # minValue=1, maxValue=9, step=2, format="%04d") + # example of IS_SALE + .withColumnSpec("IS_S", values=sale_values, weights=sale_weights, random=True) + # .withColumnSpec("nstr4", percentNulls=0.1, + # minValue=1, maxValue=9, step=2, format="%04d") + ) return testDataspec @@ -194,7 +195,7 @@ def sampleDataSet(self, sampleDataSpec): return df def setup_log_capture(self, caplog_object): - """ set up log capture fixture + """set up log capture fixture Sets up log capture fixture to only capture messages after setup and only capture warnings and errors @@ -252,7 +253,7 @@ def test_build_ordering_basic(self, sampleDataSpec): assert isinstance(el, list) def builtBefore(self, field1, field2, build_order): - """ check if field1 is built before field2""" + """check if field1 is built before field2""" fieldsBuilt = [] @@ -265,7 +266,7 @@ def builtBefore(self, field1, field2, build_order): return False def builtInSeparatePhase(self, field1, field2, build_order): - """ check if field1 is built in separate phase to field2""" + """check if field1 is built in separate phase to field2""" fieldsBuilt = [] @@ -278,20 +279,28 @@ def builtInSeparatePhase(self, field1, field2, build_order): return False def test_build_ordering_explicit_dependency(self): - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - baseColumns=["city2"]) \ - .withColumn("city2", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - baseColumns=["city_pop"]) \ - .withColumn("city_id2", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True, - baseColumn="city_id") + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, 
seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + baseColumns=["city2"], + ) + .withColumn( + "city2", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + baseColumns=["city_pop"], + ) + .withColumn( + "city_id2", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True, baseColumn="city_id" + ) + ) build_order = gen1.build_order logging.info(f"Build order {build_order}") @@ -310,20 +319,23 @@ def test_build_ordering_explicit_dependency(self): assert self.builtInSeparatePhase("city", "city_pop", build_order) def test_build_ordering_explicit_dependency2(self): - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - baseColumns=["city_name", "city_id", "city_pop"]) \ - .withColumn("city2", "struct", - expr="city", - baseColumns=["city"]) \ - .withColumn("city_id2", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True, - baseColumn="city_id") + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + baseColumns=["city_name", "city_id", "city_pop"], + ) + .withColumn("city2", "struct", expr="city", baseColumns=["city"]) + .withColumn( + "city_id2", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True, baseColumn="city_id" + ) + ) build_order = gen1.build_order logging.info(f"Build order {build_order}") @@ -336,14 +348,18 @@ def test_build_ordering_explicit_dependency2(self): assert self.builtInSeparatePhase("city", "city_pop", build_order) def test_build_ordering_implicit_dependency(self): - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', 
city_id, 'population', city_pop)") + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + ) + ) build_order = gen1.build_order logging.info(f"Build order {build_order}") @@ -377,33 +393,18 @@ def test_build_ordering_implicit_dependency2(self): .withColumn("area", "string", values=AREAS, random=True, omit=True) .withColumn("line", "string", values=LINES, random=True, omit=True) .withColumn("local_device_id", "int", maxValue=NUM_LOCAL_DEVICES - 1, omit=True, random=True) - .withColumn("local_device", "string", prefix="device", baseColumn="local_device_id") - - .withColumn("device_key", "string", - expr="concat('/', site, '/', area, '/', line, '/', local_device)") - + .withColumn("device_key", "string", expr="concat('/', site, '/', area, '/', line, '/', local_device)") # used to compute the device id - .withColumn("internal_device_key", "long", expr="hash(site, area, line, local_device)", - omit=True) - - .withColumn("deviceId", "string", format="0x%013x", - baseColumn="internal_device_key") - + .withColumn("internal_device_key", "long", expr="hash(site, area, line, local_device)", omit=True) + .withColumn("deviceId", "string", format="0x%013x", baseColumn="internal_device_key") # tag name is name of device signal .withColumn("tagName", "string", values=TAGS, random=True) - # tag value is state - .withColumn("tagValue", "string", - values=DEVICE_STATES, weights=DEVICE_WEIGHTS, - random=True) - - .withColumn("tag_ts", "timestamp", - begin=STARTING_DATETIME, - end=END_DATETIME, - interval=EVENT_INTERVAL, - random=True) - + .withColumn("tagValue", "string", values=DEVICE_STATES, weights=DEVICE_WEIGHTS, random=True) + .withColumn( + "tag_ts", "timestamp", begin=STARTING_DATETIME, end=END_DATETIME, interval=EVENT_INTERVAL, random=True + ) .withColumn("event_date", "date", expr="to_date(tag_ts)") ) @@ -442,16 +443,19 @@ def test_implicit_dependency3(self): dataspec = ( dg.DataGenerator(spark, rows=1000, partitions=4) .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. 
\\w') - .withColumn("payment_instrument_type", values=['cash', 'cc', 'app'], - random=True) - .withColumn("int_payment_instrument", "int", minValue=0000, maxValue=9999, - baseColumn="name", - baseColumnType="hash", omit=True) - .withColumn("payment_instrument", - expr="format_number(int_payment_instrument, '**** ****** *####')") + .withColumn("payment_instrument_type", values=['cash', 'cc', 'app'], random=True) + .withColumn( + "int_payment_instrument", + "int", + minValue=0000, + maxValue=9999, + baseColumn="name", + baseColumnType="hash", + omit=True, + ) + .withColumn("payment_instrument", expr="format_number(int_payment_instrument, '**** ****** *####')") .withColumn("email", template=r'\\w.\\w@\\w.com') - .withColumn("md5_payment_instrument", - expr="md5(concat(payment_instrument_type, ':', payment_instrument))") + .withColumn("md5_payment_instrument", expr="md5(concat(payment_instrument_type, ':', payment_instrument))") ) build_order = dataspec.build_order @@ -467,14 +471,14 @@ def test_implicit_dependency3(self): def test_expr_attribute(self): sql_expr = "named_struct('name', city_name, 'id', city_id, 'population', city_pop)" - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr=sql_expr) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city", "struct", expr=sql_expr) + ) columnSpec = gen1.getColumnSpec("city") @@ -482,31 +486,35 @@ def test_expr_attribute(self): def test_expr_identifier_with_spaces(self): sql_expr = "named_struct('name', city_name, 'id', city_id, 'population', city_pop)" - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city 2", "struct", - expr=sql_expr) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city 2", "struct", expr=sql_expr) + ) columnSpec = gen1.getColumnSpec("city 2") assert columnSpec.expr == sql_expr def 
test_build_ordering_duplicate_names1(self): - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("extra_field", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("extra_field", "string", template=r"\w", random=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)") + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("extra_field", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("extra_field", "string", template=r"\w", random=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + ) + ) logging.info(f"Build order {gen1.build_order}") @@ -519,14 +527,18 @@ def test_build_ordering_forward_ref(self, caplog): # caplog fixture captures log content self.setup_log_capture(caplog) - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)") \ + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + ) .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + ) logging.info(f"Build order {gen1.build_order}") @@ -534,16 +546,20 @@ def test_build_ordering_forward_ref(self, caplog): assert seed_column_warnings_and_errors >= 1, "Should not have error messages about forward references" def test_build_ordering_duplicate_names2(self): - gen1 = dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, - seedColumnName="_id") \ - .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) \ - .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, 
random=True) \ - .withColumn("city_name", "string", template=r"\w", random=True, omit=True) \ - .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) \ - .withColumn("city", "struct", - expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - baseColumns=["city_name", "city_id", "city_pop"]) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=1000, partitions=4, seedColumnName="_id") + .withColumn("id", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "long", minValue=1000000, uniqueValues=10000, random=True) + .withColumn("city_name", "string", template=r"\w", random=True, omit=True) + .withColumn("city_id", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn("city_pop", "long", minValue=1000000, uniqueValues=10000, random=True, omit=True) + .withColumn( + "city", + "struct", + expr="named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + baseColumns=["city_name", "city_id", "city_pop"], + ) + ) logging.info(f"Build order {gen1.build_order}") diff --git a/tests/test_columnGenerationSpec.py b/tests/test_columnGenerationSpec.py index c865fc95..fb3025fd 100644 --- a/tests/test_columnGenerationSpec.py +++ b/tests/test_columnGenerationSpec.py @@ -99,13 +99,17 @@ def test_default_random_attribute(self): cd = dg.ColumnGenerationSpec(name="test", colType=StringType(), baseColumn='test0,test_1', expr="concat(1,2)") assert not cd.random, "random should be False by default" - @pytest.mark.parametrize("randomSetting, expectedSetting", - [(True, True), - (False, False), - ]) + @pytest.mark.parametrize( + "randomSetting, expectedSetting", + [ + (True, True), + (False, False), + ], + ) def test_random_explicit(self, randomSetting, expectedSetting): dt = StringType() - cd = dg.ColumnGenerationSpec(name="test", colType=StringType(), baseColumn='test0,test_1', - expr="concat(1,2)", random=randomSetting) + cd = dg.ColumnGenerationSpec( + name="test", colType=StringType(), baseColumn='test0,test_1', expr="concat(1,2)", random=randomSetting + ) assert cd.random is expectedSetting, f"random should be {expectedSetting}" diff --git a/tests/test_complex_columns.py b/tests/test_complex_columns.py index 060de796..36d82800 100644 --- a/tests/test_complex_columns.py +++ b/tests/test_complex_columns.py @@ -2,8 +2,18 @@ import pytest from pyspark.sql import functions as F -from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, ArrayType, MapType, \ - BinaryType, LongType, DateType +from pyspark.sql.types import ( + StructType, + StructField, + IntegerType, + StringType, + FloatType, + ArrayType, + MapType, + BinaryType, + LongType, + DateType, +) import dbldatagen as dg @@ -31,29 +41,36 @@ def getFieldType(schema, fieldName): else: return None - @pytest.mark.parametrize("complexFieldType, expectedType, invalidValueCondition", - [("array", ArrayType(IntegerType()), "complex_field is not Null"), - ("array>", ArrayType(ArrayType(StringType())), "complex_field is not Null"), - ("map", MapType(StringType(), IntegerType()), "complex_field is not Null"), - ("struct", - StructType([StructField("a", BinaryType()), StructField("b", IntegerType()), - StructField("c", FloatType())]), - "complex_field is not Null" - ) - ]) + @pytest.mark.parametrize( + "complexFieldType, expectedType, invalidValueCondition", + [ + ("array", 
ArrayType(IntegerType()), "complex_field is not Null"), + ("array>", ArrayType(ArrayType(StringType())), "complex_field is not Null"), + ("map", MapType(StringType(), IntegerType()), "complex_field is not Null"), + ( + "struct", + StructType( + [StructField("a", BinaryType()), StructField("b", IntegerType()), StructField("c", FloatType())] + ), + "complex_field is not Null", + ), + ], + ) def test_uninitialized_complex_fields(self, complexFieldType, expectedType, invalidValueCondition, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - .withColumn("complex_field", complexFieldType) - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + .withColumn("complex_field", complexFieldType) + ) df = df_spec.build() assert df is not None, "Ensure dataframe can be created" @@ -64,23 +81,27 @@ def test_uninitialized_complex_fields(self, complexFieldType, expectedType, inva invalid_data_count = df.where(invalidValueCondition).count() assert invalid_data_count == 0, "Not expecting invalid values" - @pytest.mark.parametrize("complexFieldType, expectedType, invalidValueCondition", - [("array", ArrayType(IntegerType()), "complex_field is not Null"), - ("array>", ArrayType(ArrayType(StringType())), "complex_field is not Null"), - ("map", MapType(StringType(), IntegerType()), "complex_field is not Null"), - ("struct", - StructType([StructField("a", BinaryType()), StructField("b", IntegerType()), - StructField("c", FloatType())]), - "complex_field is not Null" - ) - ]) + @pytest.mark.parametrize( + "complexFieldType, expectedType, invalidValueCondition", + [ + ("array", ArrayType(IntegerType()), "complex_field is not Null"), + ("array>", ArrayType(ArrayType(StringType())), "complex_field is not Null"), + ("map", MapType(StringType(), IntegerType()), "complex_field is not Null"), + ( + "struct", + StructType( + [StructField("a", BinaryType()), StructField("b", IntegerType()), StructField("c", FloatType())] + ), + "complex_field is not Null", + ), + ], + ) def test_unitialized_complex_fields2(self, complexFieldType, expectedType, invalidValueCondition, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withColumn("complex_field", complexFieldType) - ) + df_spec = dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ).withColumn("complex_field", complexFieldType) df = df_spec.build() assert df is not None, "Ensure dataframe can be created" @@ -91,41 +112,62 @@ def 
test_unitialized_complex_fields2(self, complexFieldType, expectedType, inval invalid_data_count = df.where(invalidValueCondition).count() assert invalid_data_count == 0, "Not expecting invalid values" - @pytest.mark.parametrize("complexFieldType, expectedType, valueInit, validCond", - [("array", ArrayType(IntegerType()), "array(1,2,3)", - "complex_field[1] = 2"), - ("array>", ArrayType(ArrayType(StringType())), "array(array('one','two'))", - "complex_field is not Null and size(complex_field) = 1"), - ("map", MapType(StringType(), IntegerType()), "map('hello',1, 'world', 2)", - "complex_field is not Null and complex_field['hello'] = 1"), - ("struct", - StructType([StructField("a", StringType()), StructField("b", IntegerType()), - StructField("c", FloatType())]), - "named_struct('a', 'hello, world', 'b', 42, 'c', 0.25)", - "complex_field is not Null and complex_field.c = 0.25" - ), - ("struct", - StructType([StructField("a", StringType()), StructField("b", IntegerType()), - StructField("c", IntegerType())]), - "named_struct('a', code3, 'b', code1, 'c', code2)", - "complex_field is not Null and complex_field.c = code2" - ) - ]) - def test_initialized_complex_fields(self, complexFieldType, expectedType, valueInit, validCond, setupLogging): \ - # pylint: disable=too-many-positional-arguments + @pytest.mark.parametrize( + "complexFieldType, expectedType, valueInit, validCond", + [ + ("array", ArrayType(IntegerType()), "array(1,2,3)", "complex_field[1] = 2"), + ( + "array>", + ArrayType(ArrayType(StringType())), + "array(array('one','two'))", + "complex_field is not Null and size(complex_field) = 1", + ), + ( + "map", + MapType(StringType(), IntegerType()), + "map('hello',1, 'world', 2)", + "complex_field is not Null and complex_field['hello'] = 1", + ), + ( + "struct", + StructType( + [StructField("a", StringType()), StructField("b", IntegerType()), StructField("c", FloatType())] + ), + "named_struct('a', 'hello, world', 'b', 42, 'c', 0.25)", + "complex_field is not Null and complex_field.c = 0.25", + ), + ( + "struct", + StructType( + [StructField("a", StringType()), StructField("b", IntegerType()), StructField("c", IntegerType())] + ), + "named_struct('a', code3, 'b', code1, 'c', code2)", + "complex_field is not Null and complex_field.c = code2", + ), + ], + ) + def test_initialized_complex_fields( + self, complexFieldType, expectedType, valueInit, validCond, setupLogging + ): # pylint: disable=too-many-positional-arguments data_rows = 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - .withColumn("complex_field", complexFieldType, expr=valueInit, - baseColumn=['code1', 'code2', 'code3', 'code4', 'code5']) - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", 
StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + .withColumn( + "complex_field", + complexFieldType, + expr=valueInit, + baseColumn=['code1', 'code2', 'code3', 'code4', 'code5'], + ) + ) df = df_spec.build() assert df is not None, "Ensure dataframe can be created" @@ -139,12 +181,19 @@ def test_initialized_complex_fields(self, complexFieldType, expectedType, valueI def test_basic_arrays_with_columns(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + ) df = df_spec.build() df.show() @@ -152,12 +201,15 @@ def test_basic_arrays_with_columns(self, setupLogging): def test_basic_arrays_with_columns2(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("r", ArrayType(FloatType()), expr="array(floor(rand() * 350) * (86400 + 3600))", - numColumns=column_count) - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn( + "r", ArrayType(FloatType()), expr="array(floor(rand() * 350) * (86400 + 3600))", numColumns=column_count + ) + ) df = df_spec.build() df.show() @@ -165,17 +217,24 @@ def test_basic_arrays_with_columns2(self, setupLogging): def test_basic_arrays_with_columns4(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = df_spec.build() df.show() @@ -183,17 +242,20 @@ def test_basic_arrays_with_columns4(self, setupLogging): def test_basic_arrays_with_columns5(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, 
name="test_data_set1", rows=data_rows, - partitions=spark.sparkContext.defaultParallelism) - .withIdOutput() - .withColumn("r", FloatType(), minValue=1.0, maxValue=10.0, step=0.1, - numColumns=column_count, structType="array") - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + df_spec = ( + dg.DataGenerator( + spark, name="test_data_set1", rows=data_rows, partitions=spark.sparkContext.defaultParallelism + ) + .withIdOutput() + .withColumn( + "r", FloatType(), minValue=1.0, maxValue=10.0, step=0.1, numColumns=column_count, structType="array" + ) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = df_spec.build() df.show() @@ -210,50 +272,55 @@ def arraySchema(self): def test_basic_arrays_with_existing_schema(self, arraySchema, setupLogging): print(f"schema: {arraySchema}") - gen1 = (dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) - .withSchema(arraySchema) - .withColumn("anotherValue") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) + .withSchema(arraySchema) + .withColumn("anotherValue") + ) df = gen1.build() df.show() def test_basic_arrays_with_existing_schema2(self, arraySchema, setupLogging): print(f"schema: {arraySchema}") - gen1 = (dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) - .withSchema(arraySchema) - .withColumnSpec("arrayVal", numColumns=4, structType="array") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) + .withSchema(arraySchema) + .withColumnSpec("arrayVal", numColumns=4, structType="array") + ) df = gen1.build() df.show() def test_basic_arrays_with_existing_schema3(self, arraySchema, setupLogging): print(f"schema: {arraySchema}") - gen1 = (dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) - .withSchema(arraySchema) - .withColumnSpec("arrayVal", expr="array(1,2,3)") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) + .withSchema(arraySchema) + .withColumnSpec("arrayVal", expr="array(1,2,3)") + ) df = gen1.build() df.show() def test_basic_arrays_with_existing_schema4(self, arraySchema, setupLogging): print(f"schema: {arraySchema}") - gen1 = (dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) - .withSchema(arraySchema) - .withColumnSpec("arrayVal", expr="array(1,2,3)", numColumns=4, structType="array") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) + .withSchema(arraySchema) + .withColumnSpec("arrayVal", expr="array(1,2,3)", numColumns=4, structType="array") + ) df = gen1.build() df.show() def test_basic_arrays_with_existing_schema6(self, arraySchema, setupLogging): print(f"schema: {arraySchema}") - gen1 = (dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) - .withSchema(arraySchema) - 
.withColumnSpec("arrayVal", expr="array(id+1)") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="array_schema", rows=10, partitions=2) + .withSchema(arraySchema) + .withColumnSpec("arrayVal", expr="array(id+1)") + ) df = gen1.build() assert df is not None df.show() @@ -261,43 +328,56 @@ def test_basic_arrays_with_existing_schema6(self, arraySchema, setupLogging): def test_use_of_struct_in_schema1(self, setupLogging): # while this is not ideal form, ensure that it is tolerated to address reported issue # note there is no initializer for the struct and there is an override of the default `id` field - struct_type = StructType([ - StructField('id', LongType(), True), - StructField("city", StructType([ + struct_type = StructType( + [ StructField('id', LongType(), True), - StructField('population', LongType(), True) - ]), True)]) + StructField( + "city", + StructType([StructField('id', LongType(), True), StructField('population', LongType(), True)]), + True, + ), + ] + ) - gen1 = (dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=10000, partitions=4) - .withSchema(struct_type) - .withColumn("id") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=10000, partitions=4) + .withSchema(struct_type) + .withColumn("id") + ) res1 = gen1.build(withTempView=True) assert res1.count() == 10000 def test_use_of_struct_in_schema2(self, setupLogging): - struct_type = StructType([ - StructField('id', LongType(), True), - StructField("city", StructType([ + struct_type = StructType( + [ StructField('id', LongType(), True), - StructField('population', LongType(), True) - ]), True)]) + StructField( + "city", + StructType([StructField('id', LongType(), True), StructField('population', LongType(), True)]), + True, + ), + ] + ) - gen1 = (dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=10000, partitions=4) - .withSchema(struct_type) - .withColumnSpec("city", expr="named_struct('id', id, 'population', id * 1000)") - ) + gen1 = ( + dg.DataGenerator(sparkSession=spark, name="nested_schema", rows=10000, partitions=4) + .withSchema(struct_type) + .withColumnSpec("city", expr="named_struct('id', id, 'population', id * 1000)") + ) res1 = gen1.build(withTempView=True) assert res1.count() == 10000 def test_varying_arrays(self, setupLogging): - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=1000, random=True) - .withColumn("r", "float", minValue=1.0, maxValue=10.0, step=0.1, - numColumns=5) - .withColumn("observations", "array", - expr="slice(array(r_0, r_1, r_2, r_3, r_4), 1, abs(hash(id)) % 5 + 1 )", - baseColumn="r") - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=1000, random=True) + .withColumn("r", "float", minValue=1.0, maxValue=10.0, step=0.1, numColumns=5) + .withColumn( + "observations", + "array", + expr="slice(array(r_0, r_1, r_2, r_3, r_4), 1, abs(hash(id)) % 5 + 1 )", + baseColumn="r", + ) + ) df = df_spec.build() df.show() @@ -339,9 +419,7 @@ def test_single_element_array(self): df_spec = df_spec.withColumn( "test3", "string", structType="array", numFeatures=(1, 1), values=["one", "two", "three"] ) - df_spec = df_spec.withColumn( - "test4", "string", structType="array", values=["one", "two", "three"] - ) + df_spec = df_spec.withColumn("test4", "string", structType="array", values=["one", "two", "three"]) test_df = df_spec.build() @@ -381,24 +459,21 @@ def test_map_values(self): ], ) df_spec = df_spec.withColumn( - "v4", - "string", - values=["this", "is", "a", "test"], - numFeatures=1, - structType="array" 
+ "v4", "string", values=["this", "is", "a", "test"], numFeatures=1, structType="array" ) df_spec = df_spec.withColumn( "test", "map", - values=[F.map_from_arrays(F.col("v1"), F.col("v2")), - F.map_from_arrays(F.col("v1"), F.col("v3")), - F.map_from_arrays(F.col("v2"), F.col("v3")), - F.map_from_arrays(F.col("v1"), F.col("v4")), - F.map_from_arrays(F.col("v2"), F.col("v4")), - F.map_from_arrays(F.col("v3"), F.col("v4")) - ], - baseColumns=["v1", "v2", "v3", "v4"] + values=[ + F.map_from_arrays(F.col("v1"), F.col("v2")), + F.map_from_arrays(F.col("v1"), F.col("v3")), + F.map_from_arrays(F.col("v2"), F.col("v3")), + F.map_from_arrays(F.col("v1"), F.col("v4")), + F.map_from_arrays(F.col("v2"), F.col("v4")), + F.map_from_arrays(F.col("v3"), F.col("v4")), + ], + baseColumns=["v1", "v2", "v3", "v4"], ) test_df = df_spec.build() @@ -411,16 +486,22 @@ def test_inferred_column_types_disallowed1(self, setupLogging): with pytest.raises(ValueError): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", dg.INFER_DATATYPE, values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", dg.INFER_DATATYPE, values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = df_spec.build() @@ -428,20 +509,19 @@ def test_inferred_column_types_disallowed1(self, setupLogging): def test_inferred_disallowed_with_schema(self): """Test use of schema""" - schema = StructType([ - StructField("region_id", IntegerType(), True), - StructField("region_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True), - StructField("state1", StringType(), True), - StructField("state2", StringType(), True), - StructField("st_desc", StringType(), True), - - ]) + schema = StructType( + [ + StructField("region_id", IntegerType(), True), + StructField("region_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("state1", StringType(), True), + StructField("state2", StringType(), True), + StructField("st_desc", StringType(), True), + ] + ) - testDataSpec = (dg.DataGenerator(spark, name="test_data_set1", rows=10000) - .withSchema(schema) - ) + testDataSpec = dg.DataGenerator(spark, name="test_data_set1", rows=10000).withSchema(schema) with pytest.raises(ValueError): testDataSpec2 = testDataSpec.withColumnSpecs(matchTypes=[dg.INFER_DATATYPE], minValue=0, maxValue=100) @@ -451,18 +531,24 @@ def test_inferred_disallowed_with_schema(self): def test_inferred_column_basic(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = 
(dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="code1 + code2") - .withColumn("code7", dg.INFER_DATATYPE, expr="concat(code3, code4)") - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="code1 + code2") + .withColumn("code7", dg.INFER_DATATYPE, expr="concat(code3, code4)") + ) columnSpec1 = df_spec.getColumnSpec("code1") assert columnSpec1.inferDatatype is False @@ -479,18 +565,24 @@ def test_inferred_column_basic(self, setupLogging): def test_inferred_column_validate_types(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="code1 + code2") - .withColumn("code7", dg.INFER_DATATYPE, expr="concat(code3, code4)") - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="code1 + code2") + .withColumn("code7", dg.INFER_DATATYPE, expr="concat(code3, code4)") + ) df = df_spec.build() @@ -506,19 +598,25 @@ def test_inferred_column_validate_types(self, setupLogging): def test_inferred_column_structs1(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", 
values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withColumn("struct1", dg.INFER_DATATYPE, expr="named_struct('a', code1, 'b', code2)") - .withColumn("struct2", dg.INFER_DATATYPE, expr="named_struct('a', code5, 'b', code6)") - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withColumn("struct1", dg.INFER_DATATYPE, expr="named_struct('a', code1, 'b', code2)") + .withColumn("struct2", dg.INFER_DATATYPE, expr="named_struct('a', code5, 'b', code6)") + ) df = df_spec.build() @@ -533,20 +631,26 @@ def test_inferred_column_structs1(self, setupLogging): def test_inferred_column_structs2(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withColumn("struct1", dg.INFER_DATATYPE, expr="named_struct('a', code1, 'b', code2)") - .withColumn("struct2", dg.INFER_DATATYPE, expr="named_struct('a', code5, 'b', code6)") - .withColumn("struct3", dg.INFER_DATATYPE, expr="named_struct('a', struct1, 'b', struct2)") - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withColumn("struct1", dg.INFER_DATATYPE, expr="named_struct('a', code1, 'b', code2)") + .withColumn("struct2", dg.INFER_DATATYPE, expr="named_struct('a', code5, 'b', code6)") + .withColumn("struct3", dg.INFER_DATATYPE, expr="named_struct('a', struct1, 'b', struct2)") + ) df = df_spec.build() @@ -556,26 +660,40 @@ def test_inferred_column_structs2(self, setupLogging): assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]) type3 = self.getFieldType(df.schema, "struct3") assert type3 == StructType( - [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', 
IntegerType(), True)]), False), - StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), False)] + [ + StructField( + 'a', + StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), + False, + ), + StructField( + 'b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), False + ), + ] ) def test_with_struct_column1(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) - .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) + .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) + ) df = df_spec.build() @@ -587,43 +705,57 @@ def test_with_struct_column1(self, setupLogging): def test_with_struct_column2(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=['code1', 'code2']) - .withStructColumn("struct2", fields=['code5', 'code6']) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + 
.withStructColumn("struct1", fields=['code1', 'code2']) + .withStructColumn("struct2", fields=['code5', 'code6']) + ) df = df_spec.build() type1 = self.getFieldType(df.schema, "struct1") - assert type1 == StructType([StructField('code1', IntegerType(), True), StructField('code2', IntegerType(), True)]) + assert type1 == StructType( + [StructField('code1', IntegerType(), True), StructField('code2', IntegerType(), True)] + ) type2 = self.getFieldType(df.schema, "struct2") assert type2 == StructType([StructField('code5', DateType(), False), StructField('code6', StringType(), False)]) def test_with_json_struct_column(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=['code1', 'code2'], asJson=True) - .withStructColumn("struct2", fields=['code5', 'code6'], asJson=True) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields=['code1', 'code2'], asJson=True) + .withStructColumn("struct2", fields=['code5', 'code6'], asJson=True) + ) df = df_spec.build() @@ -636,19 +768,25 @@ def test_with_json_struct_column(self, setupLogging): def test_with_json_struct_column2(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields={'codes': ["code6", "code6"]}, asJson=True) - .withStructColumn("struct2", fields=['code5', 'code6'], asJson=True) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", 
values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields={'codes': ["code6", "code6"]}, asJson=True) + .withStructColumn("struct2", fields=['code5', 'code6'], asJson=True) + ) df = df_spec.build() @@ -662,20 +800,26 @@ def test_with_json_struct_column2(self, setupLogging): def test_with_struct_column3(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) - .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) - .withStructColumn("struct3", fields=[('a', 'struct1'), ('b', 'struct2')]) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) + .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) + .withStructColumn("struct3", fields=[('a', 'struct1'), ('b', 'struct2')]) + ) df = df_spec.build() @@ -685,28 +829,41 @@ def test_with_struct_column3(self, setupLogging): assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]) type3 = self.getFieldType(df.schema, "struct3") assert type3 == StructType( - [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), False), - StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), - False)]) + [ + StructField( + 'a', + StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), + False, + ), + StructField( + 'b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), False + ), + ] + ) def test_with_struct_column4(self, setupLogging): column_count = 10 data_rows = 10 * 1000 - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", 
"string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) - .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) - .withStructColumn("struct3", - fields={'a': {'a': 'code1', 'b': 'code2'}, 'b': {'a': 'code5', 'b': 'code6'}}) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields=[('a', 'code1'), ('b', 'code2')]) + .withStructColumn("struct2", fields=[('a', 'code5'), ('b', 'code6')]) + .withStructColumn("struct3", fields={'a': {'a': 'code1', 'b': 'code2'}, 'b': {'a': 'code5', 'b': 'code6'}}) + ) df = df_spec.build() @@ -716,27 +873,41 @@ def test_with_struct_column4(self, setupLogging): assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]) type3 = self.getFieldType(df.schema, "struct3") assert type3 == StructType( - [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), False), - StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), - False)]) + [ + StructField( + 'a', + StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), + False, + ), + StructField( + 'b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), False + ), + ] + ) def test_with_struct_column_err1(self, setupLogging): column_count = 10 data_rows = 10 * 1000 with pytest.raises(ValueError): - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields={'BAD_FIELD': 45}) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", 
fields={'BAD_FIELD': 45}) + ) df = df_spec.build() @@ -745,17 +916,23 @@ def test_with_struct_column_err2(self, setupLogging): data_rows = 10 * 1000 with pytest.raises(Exception): - df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=column_count, structType="array") - .withColumn("code1", "integer", minValue=100, maxValue=200) - .withColumn("code2", "integer", minValue=0, maxValue=10) - .withColumn("code3", "string", values=['one', 'two', 'three']) - .withColumn("code4", "string", values=['one', 'two', 'three']) - .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") - .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") - .withStructColumn("struct1", fields=23) - ) + df_spec = ( + dg.DataGenerator(spark, name="test_data_set1", rows=data_rows) + .withIdOutput() + .withColumn( + "r", + FloatType(), + expr="floor(rand() * 350) * (86400 + 3600)", + numColumns=column_count, + structType="array", + ) + .withColumn("code1", "integer", minValue=100, maxValue=200) + .withColumn("code2", "integer", minValue=0, maxValue=10) + .withColumn("code3", "string", values=['one', 'two', 'three']) + .withColumn("code4", "string", values=['one', 'two', 'three']) + .withColumn("code5", dg.INFER_DATATYPE, expr="current_date()") + .withColumn("code6", dg.INFER_DATATYPE, expr="concat(code3, code4)") + .withStructColumn("struct1", fields=23) + ) df = df_spec.build() diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 6928c68b..0f88b00e 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -4,8 +4,17 @@ from pyspark.sql.types import IntegerType, StringType, FloatType import dbldatagen as dg -from dbldatagen.constraints import SqlExpr, LiteralRelation, ChainedRelation, LiteralRange, RangedValues, \ - PositiveValues, NegativeValues, UniqueCombinations, Constraint +from dbldatagen.constraints import ( + SqlExpr, + LiteralRelation, + ChainedRelation, + LiteralRange, + RangedValues, + PositiveValues, + NegativeValues, + UniqueCombinations, + Constraint, +) spark = dg.SparkSingleton.getLocalInstance("unit tests") @@ -20,26 +29,23 @@ class TestConstraints: @pytest.fixture() def generationSpec1(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, - partitions=4) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") - .withColumn("code1", IntegerType(), unique_values=100) - .withColumn("code2", IntegerType(), min=1, max=200) - .withColumn("code3", IntegerType(), maxValue=10) - .withColumn("positive_and_negative", IntegerType(), minValue=-100, maxValue=100) - .withColumn("code4", StringType(), values=['a', 'b', 'c']) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code6", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") + .withColumn("code1", IntegerType(), unique_values=100) + .withColumn("code2", IntegerType(), min=1, max=200) + .withColumn("code3", IntegerType(), maxValue=10) + .withColumn("positive_and_negative", IntegerType(), minValue=-100, maxValue=100) + .withColumn("code4", StringType(), values=['a', 'b', 'c']) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], 
random=True) + .withColumn("code6", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) return testDataSpec def test_simple_constraints(self, generationSpec1): - testDataSpec = (generationSpec1 - .withSqlConstraint("id < 100") - .withSqlConstraint("id > 0") - ) + testDataSpec = generationSpec1.withSqlConstraint("id < 100").withSqlConstraint("id > 0") testDataDF = testDataSpec.build() @@ -47,9 +53,7 @@ def test_simple_constraints(self, generationSpec1): assert rowCount == 99 def test_simple_constraints2(self, generationSpec1): - testDataSpec = (generationSpec1 - .withConstraint(SqlExpr("id < 100")) - ) + testDataSpec = generationSpec1.withConstraint(SqlExpr("id < 100")) testDataDF = testDataSpec.build() @@ -57,10 +61,7 @@ def test_simple_constraints2(self, generationSpec1): assert rowCount == 100 def test_multiple_constraints(self, generationSpec1): - testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - ) + testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]) testDataDF = testDataSpec.build() @@ -69,20 +70,20 @@ def test_multiple_constraints(self, generationSpec1): def test_streaming_exception(self, generationSpec1): with pytest.raises(RuntimeError): - testDataSpec = (generationSpec1 - .withConstraint(UniqueCombinations(["code1", "code2"])) - ) + testDataSpec = generationSpec1.withConstraint(UniqueCombinations(["code1", "code2"])) testDataDF = testDataSpec.build(withStreaming=True) assert testDataDF is not None - @pytest.mark.parametrize("constraints,producesExpression", - [ - ([SqlExpr("id < 100"), SqlExpr("id > 0")], True), - ([UniqueCombinations()], False), - ([UniqueCombinations("*")], False), - ([UniqueCombinations(["a", "b"])], False), - ]) + @pytest.mark.parametrize( + "constraints,producesExpression", + [ + ([SqlExpr("id < 100"), SqlExpr("id > 0")], True), + ([UniqueCombinations()], False), + ([UniqueCombinations("*")], False), + ([UniqueCombinations(["a", "b"])], False), + ], + ) def test_combine_constraints(self, constraints, producesExpression): constraintExpressions = [c.filterExpression for c in constraints] @@ -100,61 +101,60 @@ def test_constraint_filter_expression_cache(self): filterExpression2 = constraint.filterExpression assert filterExpression is filterExpression2 - @pytest.mark.parametrize("column, operation, literalValue, expectedRows", - [ - ("id", "<", 50, 49), - ("id", "<=", 50, 50), - ("id", ">", 50, 49), - ("id", ">=", 50, 50), - ("id", "==", 50, 1), - ("id", "!=", 50, 98), - ]) - def test_scalar_relation(self, column, operation, literalValue, expectedRows, generationSpec1): \ - # pylint: disable=too-many-positional-arguments - - testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(LiteralRelation(column, operation, literalValue)) - ) + @pytest.mark.parametrize( + "column, operation, literalValue, expectedRows", + [ + ("id", "<", 50, 49), + ("id", "<=", 50, 50), + ("id", ">", 50, 49), + ("id", ">=", 50, 50), + ("id", "==", 50, 1), + ("id", "!=", 50, 98), + ], + ) + def test_scalar_relation( + self, column, operation, literalValue, expectedRows, generationSpec1 + ): # pylint: disable=too-many-positional-arguments + + testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + LiteralRelation(column, operation, literalValue) + ) testDataDF = testDataSpec.build() rowCount = testDataDF.count() assert rowCount == expectedRows - @pytest.mark.parametrize("columns, 
strictFlag, expectedRows", - [ - ("positive_and_negative", True, 99), - ("positive_and_negative", False, 100), - ("positive_and_negative", "skip", 100), - ]) + @pytest.mark.parametrize( + "columns, strictFlag, expectedRows", + [ + ("positive_and_negative", True, 99), + ("positive_and_negative", False, 100), + ("positive_and_negative", "skip", 100), + ], + ) def testNegativeValues(self, generationSpec1, columns, strictFlag, expectedRows): - testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(NegativeValues(columns, strict=strictFlag) if strictFlag != "skip" - else NegativeValues(columns)) - ) + testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + NegativeValues(columns, strict=strictFlag) if strictFlag != "skip" else NegativeValues(columns) + ) testDataDF = testDataSpec.build() rowCount = testDataDF.count() assert rowCount == 99 - @pytest.mark.parametrize("columns, strictFlag, expectedRows", - [ - ("positive_and_negative", True, 99), - ("positive_and_negative", False, 100), - ("positive_and_negative", "skip", 100), - ]) + @pytest.mark.parametrize( + "columns, strictFlag, expectedRows", + [ + ("positive_and_negative", True, 99), + ("positive_and_negative", False, 100), + ("positive_and_negative", "skip", 100), + ], + ) def testPositiveValues(self, generationSpec1, columns, strictFlag, expectedRows): - testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 200"), - SqlExpr("id > 0")]) - .withConstraint(PositiveValues(columns, strict=strictFlag) if strictFlag != "skip" - else PositiveValues(columns)) - ) + testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 200"), SqlExpr("id > 0")]).withConstraint( + PositiveValues(columns, strict=strictFlag) if strictFlag != "skip" else PositiveValues(columns) + ) testDataDF = testDataSpec.build() @@ -163,11 +163,9 @@ def testPositiveValues(self, generationSpec1, columns, strictFlag, expectedRows) def test_scalar_relation_bad(self, generationSpec1): with pytest.raises(ValueError): - testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(LiteralRelation("id", "<<<", 50)) - ) + testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + LiteralRelation("id", "<<<", 50) + ) testDataDF = testDataSpec.build() @@ -176,50 +174,50 @@ def test_scalar_relation_bad(self, generationSpec1): @pytest.fixture() def generationSpec2(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, - partitions=4) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") - .withColumn("code1", IntegerType(), min=1, max=100) - .withColumn("code2", IntegerType(), min=50, max=150) - .withColumn("code3", IntegerType(), min=100, max=200) - .withColumn("code4", IntegerType(), min=1, max=300) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") + .withColumn("code1", IntegerType(), min=1, max=100) + .withColumn("code2", IntegerType(), min=50, max=150) + .withColumn("code3", IntegerType(), min=100, max=200) + .withColumn("code4", IntegerType(), min=1, max=300) + ) return testDataSpec - @pytest.mark.parametrize("columns, operation, expectedRows", - [ - (["code1", "code2", "code3"], "<", 99), - (["code1", "code2", "code3"], "<=", 99), 
- (["code3", "code2", "code1"], ">", 99), - (["code3", "code2", "code1"], ">=", 99), - ]) + @pytest.mark.parametrize( + "columns, operation, expectedRows", + [ + (["code1", "code2", "code3"], "<", 99), + (["code1", "code2", "code3"], "<=", 99), + (["code3", "code2", "code1"], ">", 99), + (["code3", "code2", "code1"], ">=", 99), + ], + ) def test_chained_relation(self, generationSpec2, columns, operation, expectedRows): - testDataSpec = (generationSpec2 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(ChainedRelation(columns, operation)) - ) + testDataSpec = generationSpec2.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + ChainedRelation(columns, operation) + ) testDataDF = testDataSpec.build() rowCount = testDataDF.count() assert rowCount == expectedRows - @pytest.mark.parametrize("columns, operation", - [ - (["code3", "code2", "code1"], "<<<"), - (None, "<="), - (["code3"], ">"), - ]) + @pytest.mark.parametrize( + "columns, operation", + [ + (["code3", "code2", "code1"], "<<<"), + (None, "<="), + (["code3"], ">"), + ], + ) def test_chained_relation_bad(self, generationSpec2, columns, operation): with pytest.raises(ValueError): - testDataSpec = (generationSpec2 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(ChainedRelation(columns, operation)) - ) + testDataSpec = generationSpec2.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + ChainedRelation(columns, operation) + ) testDataDF = testDataSpec.build() @@ -231,9 +229,7 @@ def test_unique_combinations(self, generationSpec2): validationCount2 = df.dropDuplicates(['code1', 'code4']).count() print(validationCount, validationCount2) - testDataSpec = (generationSpec2 - .withConstraint(UniqueCombinations(["code1", "code4"])) - ) + testDataSpec = generationSpec2.withConstraint(UniqueCombinations(["code1", "code4"])) testDataDF = testDataSpec.build() @@ -242,13 +238,13 @@ def test_unique_combinations(self, generationSpec2): @pytest.fixture() def generationSpec3(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, - partitions=4) - .withColumn("code1", IntegerType(), min=1, max=20) - .withColumn("code2", IntegerType(), min=1, max=30) - .withColumn("code3", IntegerType(), min=1, max=5) - .withColumn("code4", IntegerType(), min=1, max=10) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withColumn("code1", IntegerType(), min=1, max=20) + .withColumn("code2", IntegerType(), min=1, max=30) + .withColumn("code3", IntegerType(), min=1, max=5) + .withColumn("code4", IntegerType(), min=1, max=10) + ) return testDataSpec @@ -260,9 +256,7 @@ def test_unique_combinations2(self, generationSpec3): validationCount2 = df.dropDuplicates().count() print(validationCount, validationCount2) - testDataSpec = (generationSpec3 - .withConstraint(UniqueCombinations("*")) - ) + testDataSpec = generationSpec3.withConstraint(UniqueCombinations("*")) testDataDF = testDataSpec.build() @@ -270,38 +264,33 @@ def test_unique_combinations2(self, generationSpec3): print("rowCount", rowCount) assert rowCount == validationCount - @pytest.mark.parametrize("column, minValue, maxValue, strictFlag, expectedRows", - [ - ("id", 0, 100, True, 99), - ("id", 0, 100, False, 99), - ("id", 10, 20, True, 9), - ("id", 10, 20, False, 11), - ]) - def test_literal_range(self, column, minValue, maxValue, strictFlag, expectedRows, generationSpec2): \ - # pylint: 
disable=too-many-positional-arguments - - testDataSpec = (generationSpec2 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(LiteralRange(column, minValue, maxValue, strict=strictFlag)) - ) + @pytest.mark.parametrize( + "column, minValue, maxValue, strictFlag, expectedRows", + [ + ("id", 0, 100, True, 99), + ("id", 0, 100, False, 99), + ("id", 10, 20, True, 9), + ("id", 10, 20, False, 11), + ], + ) + def test_literal_range( + self, column, minValue, maxValue, strictFlag, expectedRows, generationSpec2 + ): # pylint: disable=too-many-positional-arguments + + testDataSpec = generationSpec2.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + LiteralRange(column, minValue, maxValue, strict=strictFlag) + ) testDataDF = testDataSpec.build() rowCount = testDataDF.count() assert rowCount == expectedRows - @pytest.mark.parametrize("strictSetting, expectedRows", - [ - (True, 99), - (False, 99) - ]) + @pytest.mark.parametrize("strictSetting, expectedRows", [(True, 99), (False, 99)]) def test_ranged_values(self, generationSpec2, strictSetting, expectedRows): - testDataSpec = (generationSpec2 - .withConstraints([SqlExpr("id < 100"), - SqlExpr("id > 0")]) - .withConstraint(RangedValues("code2", "code1", "code3", strict=strictSetting)) - ) + testDataSpec = generationSpec2.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]).withConstraint( + RangedValues("code2", "code1", "code3", strict=strictSetting) + ) testDataDF = testDataSpec.build() diff --git a/tests/test_data_generation_plugins.py b/tests/test_data_generation_plugins.py index 81b9592b..cfe838dc 100644 --- a/tests/test_data_generation_plugins.py +++ b/tests/test_data_generation_plugins.py @@ -20,11 +20,11 @@ def test_plugins(self, dataRows): def initPluginContext(context): context.prefix = "testing" - text_generator = (lambda context, v: context.prefix + str(v)) + text_generator = lambda context, v: context.prefix + str(v) - pluginDataspec = (dg.DataGenerator(spark, rows=dataRows, partitions=partitions_requested) - .withColumn("text", text=PyfuncText(text_generator, init=initPluginContext)) - ) + pluginDataspec = dg.DataGenerator(spark, rows=dataRows, partitions=partitions_requested).withColumn( + "text", text=PyfuncText(text_generator, init=initPluginContext) + ) dfPlugin = pluginDataspec.build() assert dfPlugin.count() == dataRows @@ -41,11 +41,11 @@ def test_plugin_clone(self): def initPluginContext(context): context.prefix = "testing" - text_generator = (lambda context, v: context.prefix + str(v)) + text_generator = lambda context, v: context.prefix + str(v) - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("text", text=PyfuncText(text_generator, init=initPluginContext)) - ) + pluginDataspec = dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested).withColumn( + "text", text=PyfuncText(text_generator, init=initPluginContext) + ) dfPlugin = pluginDataspec.build() dfCheck = dfPlugin.where("text like 'testing%'") @@ -64,7 +64,7 @@ def initPluginContext(context): assert new_count2 == data_rows def test_plugins_extended_syntax(self): - """ test property syntax""" + """test property syntax""" partitions_requested = 4 data_rows = 100 * 1000 @@ -80,9 +80,9 @@ def initPluginContext(context): CustomText = PyfuncTextFactory(name="CustomText").withInit(initPluginContext).withRootProperty("root") - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("text", 
text=CustomText("mkText")) - ) + pluginDataspec = dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested).withColumn( + "text", text=CustomText("mkText") + ) dfPlugin = pluginDataspec.build() assert dfPlugin.count() == data_rows @@ -93,7 +93,7 @@ def initPluginContext(context): assert new_count == data_rows def test_plugins_extended_syntax2(self): - """ test arg passing""" + """test arg passing""" partitions_requested = 4 data_rows = 100 * 1000 @@ -110,9 +110,9 @@ def initPluginContext(context): CustomText = PyfuncTextFactory(name="CustomText").withInit(initPluginContext).withRootProperty("root") - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("text", text=CustomText("mkText", isProperty=True)) - ) + pluginDataspec = dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested).withColumn( + "text", text=CustomText("mkText", isProperty=True) + ) dfPlugin = pluginDataspec.build() assert dfPlugin.count() == data_rows @@ -138,9 +138,9 @@ def initPluginContext(context): CustomText = PyfuncTextFactory(name="CustomText").withInit(initPluginContext).withRootProperty("root") - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("text", text=CustomText("mkText", extra="again")) - ) + pluginDataspec = dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested).withColumn( + "text", text=CustomText("mkText", extra="again") + ) dfPlugin = pluginDataspec.build() assert dfPlugin.count() == data_rows @@ -151,7 +151,7 @@ def initPluginContext(context): assert new_count == data_rows def test_plugins_extended_syntax4(self): - """ Test syntax extensions """ + """Test syntax extensions""" partitions_requested = 4 data_rows = 100 * 1000 @@ -178,9 +178,10 @@ def initPluginContext(context): assert x == "testing1again" def test_plugins_faker_integration(self): - """ test faker integration with mock objects""" + """test faker integration with mock objects""" import unittest.mock + shuffle_partitions_requested = 4 partitions_requested = 4 data_rows = 30 * 1000 @@ -196,19 +197,21 @@ def test_plugins_faker_integration(self): # partition parameters etc. 
spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions_requested) - fakerDataspec2 = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("customer_id", "int", uniqueValues=uniqueCustomers) - .withColumn("name", text=FakerText("__str__")) # use __str__ as it returns text - ) + fakerDataspec2 = ( + dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) + .withColumn("customer_id", "int", uniqueValues=uniqueCustomers) + .withColumn("name", text=FakerText("__str__")) # use __str__ as it returns text + ) dfFaker2 = fakerDataspec2.build() output = dfFaker2.select("name").collect() for x in output: assert x["name"].startswith(" 1").count(), - "should not more than one value for line for same manufacturer ") + self.assertEqual( + 0, df2.where("c > 1").count(), "should not more than one value for line for same manufacturer " + ) def test_dependent_line2(self): self.assertEqual(0, self.dfTestData.where("line2 is null").count(), "should not have null values") @@ -139,8 +178,9 @@ def test_dependent_line2(self): df2 = spark.sql("select count(distinct line2) as c, manufacturer from test_data3 group by manufacturer") df2.show() - self.assertEqual(0, df2.where("c > 1").count(), - "should not more than one value for line for same manufacturer ") + self.assertEqual( + 0, df2.where("c > 1").count(), "should not more than one value for line for same manufacturer " + ) def test_dependent_country(self): self.assertEqual(0, self.dfTestData.where("country is null").count(), "should not have null values") @@ -149,16 +189,19 @@ def test_dependent_country(self): self.dfTestData.createOrReplaceTempView("test_data") df2 = spark.sql("select count(distinct country) as c from test_data group by device_id") - self.assertEqual(0, df2.where("c > 1").count(), - "should not more than one value for country for same device id ") + self.assertEqual( + 0, df2.where("c > 1").count(), "should not more than one value for country for same device id " + ) def test_spread_of_country_value1(self): # for given device id, should have the same country self.dfTestData.createOrReplaceTempView("test_data") - df2 = spark.sql(""" + df2 = spark.sql( + """ select count(distinct country) from test_data - """) + """ + ) results = df2.collect()[0][0] @@ -168,10 +211,12 @@ def test_spread_of_country_value2(self): # for given device id, should have the same country self.dfTestData.createOrReplaceTempView("test_data") - df2 = spark.sql(""" + df2 = spark.sql( + """ select count(distinct country2) from test_data - """) + """ + ) results = df2.collect()[0][0] @@ -180,33 +225,44 @@ def test_spread_of_country_value2(self): def test_format_dependent_data(self): ds_copy1 = self.testDataSpec.clone() - df_copy1 = (ds_copy1.withRowCount(1000) - .withColumn("device_id_2", StringType(), format='0x%013x', baseColumn="internal_device_id", - base_column_type="values") - .build()) + df_copy1 = ( + ds_copy1.withRowCount(1000) + .withColumn( + "device_id_2", + StringType(), + format='0x%013x', + baseColumn="internal_device_id", + base_column_type="values", + ) + .build() + ) df_copy1.show() # check data is not null and has unique values - count_distinct = (df_copy1.where("device_id_2 is not null") + count_distinct = ( + df_copy1.where("device_id_2 is not null") .agg(F.countDistinct('device_id_2').alias('count_d')) .collect()[0]['count_d'] - ) + ) self.assertGreaterEqual(count_distinct, 1) def test_format_dependent_data2(self): - """ Test without specifying the base column type""" + """Test without 
specifying the base column type""" ds_copy1 = self.testDataSpec.clone() - df_copy1 = (ds_copy1.withRowCount(1000) - .withColumn("device_id_2", StringType(), format='0x%013x', baseColumn="internal_device_id") - .build()) + df_copy1 = ( + ds_copy1.withRowCount(1000) + .withColumn("device_id_2", StringType(), format='0x%013x', baseColumn="internal_device_id") + .build() + ) df_copy1.show() # check data is not null and has unique values - count_distinct = (df_copy1.where("device_id_2 is not null") + count_distinct = ( + df_copy1.where("device_id_2 is not null") .agg(F.countDistinct('device_id_2').alias('count_d')) .collect()[0]['count_d'] - ) + ) self.assertGreaterEqual(count_distinct, 1) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index a668e3fc..861538c9 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -44,11 +44,14 @@ def get_observed_weights(cls, df, column, values): assert column is not None assert values is not None - observed_weights = (df.cube(column).count() - .withColumnRenamed(column, "value") - .withColumnRenamed("count", "rc") - .where("value is not null") - .collect()) + observed_weights = ( + df.cube(column) + .count() + .withColumnRenamed(column, "value") + .withColumnRenamed("count", "rc") + .where("value is not null") + .collect() + ) print(observed_weights) @@ -64,6 +67,7 @@ def test_valid_basic_distribution(self, basicDistributionInstance): def test_bad_distribution_inheritance(self, basicDistributionInstance): # define a bad derived class (due to lack of abstract methods) to test enforcement with pytest.raises(TypeError): + class MyDistribution(dist.DataDistribution): def dummyMethod(self): pass @@ -115,18 +119,28 @@ def test_simple_normal_distribution(self): .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution="normal") - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_normal_data = normal_data_generator.build().cache() - df_summary_general = df_normal_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_normal_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -140,20 +154,31 @@ def test_normal_distribution(self): dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Normal(1.0, 1.0)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", 
values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Normal(1.0, 1.0) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_normal_data = normal_data_generator.build().cache() - df_summary_general = df_normal_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_normal_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -167,20 +192,31 @@ def test_normal_distribution_seeded1(self): dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4, seed=42) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Normal(1.0, 1.0)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Normal(1.0, 1.0) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_normal_data = normal_data_generator.build().cache() - df_summary_general = df_normal_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_normal_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -191,24 +227,36 @@ def test_normal_distribution_seeded1(self): def test_normal_distribution_seeded2(self): # will have implied column `id` for ordinal of row normal_data_generator = ( - dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4, - seed=42, seedMethod="hash_fieldname") + dg.DataGenerator( + sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4, seed=42, seedMethod="hash_fieldname" + ) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Normal(1.0, 1.0)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Normal(1.0, 
1.0) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_normal_data = normal_data_generator.build().cache() - df_summary_general = df_normal_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_normal_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -250,20 +298,31 @@ def test_gamma_distribution(self): dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Gamma(0.5, 0.5)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Gamma(0.5, 0.5) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_gamma_data = gamma_data_generator.build().cache() - df_summary_general = df_gamma_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_gamma_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -305,20 +364,31 @@ def test_beta_distribution(self): dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Beta(0.5, 0.5)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Beta(0.5, 0.5) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_beta_data = beta_data_generator.build().cache() - df_summary_general = df_beta_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + 
df_summary_general = df_beta_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] @@ -359,20 +429,31 @@ def test_exponential_distribution(self): dg.DataGenerator(sparkSession=spark, rows=self.TESTDATA_ROWS, partitions=4) .withIdOutput() # id column will be emitted in the output .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True, - distribution=dist.Exponential(0.5)) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True, distribution="normal") - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withColumn( + "code4", "integer", minValue=1, maxValue=40, step=1, random=True, distribution=dist.Exponential(0.5) + ) + .withColumn( + "sector_status_desc", + "string", + minValue=1, + maxValue=200, + step=1, + prefix='status', + random=True, + distribution="normal", + ) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) df_exponential_data = exponential_data_generator.build().cache() - df_summary_general = df_exponential_data.agg(F.min('code4').alias('min_c4'), - F.max('code4').alias('max_c4'), - F.avg('code4').alias('mean_c4'), - F.stddev('code4').alias('stddev_c4')) + df_summary_general = df_exponential_data.agg( + F.min('code4').alias('min_c4'), + F.max('code4').alias('max_c4'), + F.avg('code4').alias('mean_c4'), + F.stddev('code4').alias('stddev_c4'), + ) df_summary_general.show() summary_data = df_summary_general.collect()[0] diff --git a/tests/test_generation_from_data.py b/tests/test_generation_from_data.py index f5793b8a..19688ac6 100644 --- a/tests/test_generation_from_data.py +++ b/tests/test_generation_from_data.py @@ -27,31 +27,38 @@ def testLogger(self): @pytest.fixture def generation_spec(self): spec = ( - dg.DataGenerator(sparkSession=spark, name='test_generator', - rows=self.SMALL_ROW_COUNT, seedMethod='hash_fieldname') + dg.DataGenerator( + sparkSession=spark, name='test_generator', rows=self.SMALL_ROW_COUNT, seedMethod='hash_fieldname' + ) .withColumn('asin', 'string', template=r"adddd", random=True) .withColumn('brand', 'string', template=r"\w|\w \w \w|\w \w \w") .withColumn('helpful', 'array', expr="array(floor(rand()*100), floor(rand()*100))") - .withColumn('img', 'string', expr="concat('http://www.acme.com/downloads/images/', asin, '.png')", - baseColumn="asin") + .withColumn( + 'img', 'string', expr="concat('http://www.acme.com/downloads/images/', asin, '.png')", baseColumn="asin" + ) .withColumn('price', 'double', min=1.0, max=999.0, random=True, step=0.01) .withColumn('rating', 'double', values=[1.0, 2, 0, 3.0, 4.0, 5.0], random=True) .withColumn('review', 'string', text=dg.ILText((1, 3), (1, 4), (3, 8)), random=True) .withColumn('time', 'bigint', expr="now()", percentNulls=0.1) .withColumn('title', 'string', template=r"\w|\w \w \w|\w \w \w||\w \w \w \w", random=True) .withColumn('user', 'string', expr="hex(abs(hash(id)))") - .withColumn("event_ts", "timestamp", begin="2020-01-01 01:00:00", - end="2020-12-31 23:59:00", - interval="1 minute", random=True) - .withColumn("r_value", "float", expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=(2, 4), structType="array") + .withColumn( + "event_ts", + "timestamp", + 
begin="2020-01-01 01:00:00", + end="2020-12-31 23:59:00", + interval="1 minute", + random=True, + ) + .withColumn( + "r_value", "float", expr="floor(rand() * 350) * (86400 + 3600)", numColumns=(2, 4), structType="array" + ) .withColumn("tf_flag", "boolean", expr="id % 2 = 1") .withColumn("short_value", "short", max=32767, percentNulls=0.1) .withColumn("byte_value", "tinyint", max=127) .withColumn("decimal_value", "decimal(10,2)", max=1000000) .withColumn("date_value", "date", expr="current_date()", random=True) .withColumn("binary_value", "binary", expr="cast('spark' as binary)", random=True) - ) return spec @@ -111,4 +118,4 @@ def test_df_containing_summary(self): df = spark.range(10).withColumnRenamed("id", "summary") summary_df = dg.DataAnalyzer(sparkSession=spark, df=df).summarizeToDF() - assert summary_df.count() == 10 \ No newline at end of file + assert summary_df.count() == 10 diff --git a/tests/test_html_utils.py b/tests/test_html_utils.py index 92758bb2..b7fb88be 100644 --- a/tests/test_html_utils.py +++ b/tests/test_html_utils.py @@ -7,23 +7,32 @@ class TestHtmlUtils: - @pytest.mark.parametrize("content", - [""" + @pytest.mark.parametrize( + "content", + [ + """ for x in range(10): print(x) - """] - ) + """ + ], + ) def test_html_format_code(self, content): formattedContent = HtmlUtils.formatCodeAsHtml(content) assert formattedContent is not None assert content in formattedContent - @pytest.mark.parametrize("content, heading", - [(""" + @pytest.mark.parametrize( + "content, heading", + [ + ( + """ this is a test this is another one - """, "testing" - )]) + """, + "testing", + ) + ], + ) def test_html_format_content(self, content, heading): formattedContent = HtmlUtils.formatTextAsHtml(content, title=heading) diff --git a/tests/test_iltext_generation.py b/tests/test_iltext_generation.py index 12301b2a..72564425 100644 --- a/tests/test_iltext_generation.py +++ b/tests/test_iltext_generation.py @@ -7,23 +7,25 @@ import dbldatagen as dg from dbldatagen import ILText -schema = StructType([ - StructField("PK1", StringType(), True), - StructField("LAST_MODIFIED_UTC", TimestampType(), True), - StructField("date", DateType(), True), - StructField("str1", StringType(), True), - StructField("nint", IntegerType(), True), - StructField("nstr1", StringType(), True), - StructField("nstr2", StringType(), True), - StructField("nstr3", StringType(), True), - StructField("nstr4", StringType(), True), - StructField("nstr5", StringType(), True), - StructField("nstr6", StringType(), True), - StructField("email", StringType(), True), - StructField("ip_addr", StringType(), True), - StructField("phone", StringType(), True), - StructField("isDeleted", BooleanType(), True) -]) +schema = StructType( + [ + StructField("PK1", StringType(), True), + StructField("LAST_MODIFIED_UTC", TimestampType(), True), + StructField("date", DateType(), True), + StructField("str1", StringType(), True), + StructField("nint", IntegerType(), True), + StructField("nstr1", StringType(), True), + StructField("nstr2", StringType(), True), + StructField("nstr3", StringType(), True), + StructField("nstr4", StringType(), True), + StructField("nstr5", StringType(), True), + StructField("nstr6", StringType(), True), + StructField("email", StringType(), True), + StructField("ip_addr", StringType(), True), + StructField("phone", StringType(), True), + StructField("isDeleted", BooleanType(), True), + ] +) # add the following if using pandas udfs # .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") \ @@ -47,26 +49,26 @@ def 
setUp(self): @classmethod def setUpClass(cls): print("setting up class ") - cls.testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=cls.row_count, - partitions=cls.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percent_nulls=0.1) - .withColumnSpec("nint", percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr1", percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr2", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - format="%04.1f") - .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) - .withColumnSpec("nstr4", percent_nulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - .withColumnSpec("nstr5", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - random=True) - .withColumnSpec("nstr6", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - random=True, - format="%04f") - .withColumnSpec("email", template=r'\\w.\\w@\\w.com|\\w@\\w.co.u\\k') - .withColumnSpec("ip_addr", template=r'\\n.\\n.\\n.\\n') - .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') - ) + cls.testDataSpec = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set1", rows=cls.row_count, partitions=cls.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percent_nulls=0.1) + .withColumnSpec("nint", percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr1", percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr2", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, format="%04.1f") + .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) + .withColumnSpec("nstr4", percent_nulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + .withColumnSpec("nstr5", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumnSpec( + "nstr6", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, format="%04f" + ) + .withColumnSpec("email", template=r'\\w.\\w@\\w.com|\\w@\\w.co.u\\k') + .withColumnSpec("ip_addr", template=r'\\n.\\n.\\n.\\n') + .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') + ) def test_basic_data_generation(self): results = self.testDataSpec.build() @@ -79,8 +81,9 @@ def test_phone_number_generation(self): results = self.testDataSpec.build() # check phone numbers - phone_values = [r[0] for r in - results.select(expr(r"regexp_replace(phone, '(\\d+)', 'num')")).distinct().collect()] + phone_values = [ + r[0] for r in results.select(expr(r"regexp_replace(phone, '(\\d+)', 'num')")).distinct().collect() + ] print(phone_values) self.assertSetEqual(set(phone_values), {'num num', '(num)-num-num', 'num(num) num-num'}) @@ -88,8 +91,9 @@ def test_email_generation(self): results = self.testDataSpec.build() # check email addresses - email_values = [r[0] for r in - results.select(expr(r"regexp_replace(email, '(\\w+)', 'word')")).distinct().collect()] + email_values = [ + r[0] for r in results.select(expr(r"regexp_replace(email, '(\\w+)', 'word')")).distinct().collect() + ] print(email_values) self.assertSetEqual(set(email_values), {'word@word.word.word', 'word.word@word.word'}) @@ -97,8 +101,9 @@ def test_ip_address_generation(self): results = self.testDataSpec.build() # check ip address - ip_address_values = [r[0] for r in - results.select(expr(r"regexp_replace(ip_addr, '(\\d+)', 'num')")).distinct().collect()] + ip_address_values = [ + r[0] for r in 
results.select(expr(r"regexp_replace(ip_addr, '(\\d+)', 'num')")).distinct().collect() + ] print(ip_address_values) self.assertSetEqual(set(ip_address_values), {'num.num.num.num'}) @@ -123,12 +128,14 @@ def test_basic_text_data_generation(self): def test_iltext1(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("nstr1", text=ILText(words=(2, 6))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("nstr1", text=ILText(words=(2, 6))) + ) results2 = testDataSpec2.build().select("id", "nstr1").cache() results2.show() @@ -137,12 +144,14 @@ def test_iltext1(self): def test_iltext2(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(paragraphs=(2, 6))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(paragraphs=(2, 6))) + ) testDataSpec2.build().select("id", "phone").show() @@ -150,12 +159,14 @@ def test_iltext2(self): def test_iltext3(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(sentences=(2, 6))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(sentences=(2, 6))) + ) testDataSpec2.build().select("id", "phone").show() @@ -163,13 +174,19 @@ def test_iltext3(self): def test_iltext4a(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested, - usePandas=True, batchSize=300) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), sentences=(2, 6))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set2", + rows=self.row_count, + partitions=self.partitions_requested, + usePandas=True, + batchSize=300, + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), sentences=(2, 6))) + ) testDataSpec2.build().select("id", "phone").show(20, truncate=False) @@ -177,12 +194,18 @@ def test_iltext4a(self): def test_iltext4b(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested, usePandas=False) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), sentences=(2, 6))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set2", + rows=self.row_count, + partitions=self.partitions_requested, + usePandas=False, + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), 
sentences=(2, 6))) + ) testDataSpec2.build().select("id", "phone").show(20, truncate=False) @@ -190,12 +213,14 @@ def test_iltext4b(self): def test_iltext5(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), words=(3, 12))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), words=(3, 12))) + ) testDataSpec2.build().select("id", "phone").show() @@ -203,12 +228,14 @@ def test_iltext5(self): def test_iltext6(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), sentences=(2, 8), words=(3, 12))) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("phone", text=ILText(paragraphs=(1, 4), sentences=(2, 8), words=(3, 12))) + ) testDataSpec2.build().select("id", "phone").show() @@ -216,16 +243,19 @@ def test_iltext6(self): def test_iltext7(self): print("test data spec 2") - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withIdOutput() - .withColumn("sample_text", text=ILText(paragraphs=1, sentences=5, words=4)) - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withIdOutput() + .withColumn("sample_text", text=ILText(paragraphs=1, sentences=5, words=4)) + ) testDataSpec2.build().select("id", "sample_text").show() # TODO: add validation statement + # run the tests # if __name__ == '__main__': # print("Trying to run tests") diff --git a/tests/test_large_schema.py b/tests/test_large_schema.py index d7fe1260..f11c6071 100644 --- a/tests/test_large_schema.py +++ b/tests/test_large_schema.py @@ -6,141 +6,143 @@ import dbldatagen as dg -schema = StructType([ - StructField("PK1", StringType(), True), - StructField("XYYZ_IDS", StringType(), True), - StructField("R_ID", IntegerType(), True), - StructField("CL_ID", StringType(), True), - StructField("INGEST_DATE", TimestampType(), True), - StructField("CMPY_ID", DecimalType(38, 0), True), - StructField("TXN_ID", DecimalType(38, 0), True), - StructField("SEQUENCE_NUMBER", DecimalType(38, 0), True), - StructField("DETAIL_ORDER", DecimalType(38, 0), True), - StructField("TX_T_ID", DecimalType(38, 0), True), - StructField("TXN_DATE", TimestampType(), True), - StructField("AN_ID", DecimalType(38, 0), True), - StructField("ANC_ID", DecimalType(38, 0), True), - StructField("ANV_ID", DecimalType(38, 0), True), - StructField("ANE_ID", DecimalType(38, 0), True), - StructField("AND_ID", DecimalType(38, 0), True), - StructField("APM_ID", DecimalType(38, 0), True), - StructField("ACL_ID", DecimalType(38, 0), True), - StructField("MEMO_TEXT", StringType(), True), - StructField("ITEM_ID", DecimalType(38, 0), True), - StructField("ITEM2_ID", DecimalType(38, 0), True), - StructField("V1_BASE", 
DecimalType(38, 9), True), - StructField("V1_YTD_AMT", DecimalType(38, 9), True), - StructField("V1_YTD_HOURS", DecimalType(38, 0), True), - StructField("ISTT", DecimalType(38, 9), True), - StructField("XXX_AMT", StringType(), True), - StructField("XXX_BASE", StringType(), True), - StructField("XXX_ISTT", StringType(), True), - StructField("HOURS", DecimalType(38, 0), True), - StructField("STATE", DecimalType(38, 0), True), - StructField("LSTATE", DecimalType(38, 0), True), - StructField("XXX_JURISDICTION_ID", DecimalType(38, 0), True), - StructField("XXY_JURISDICTION_ID", DecimalType(38, 0), True), - StructField("AS_OF_DATE", TimestampType(), True), - StructField("IS_PAYOUT", StringType(), True), - StructField("IS_PYRL_LIABILITY", StringType(), True), - StructField("IS_PYRL_SUMMARY", StringType(), True), - StructField("PYRL_LIABILITY_DATE", TimestampType(), True), - StructField("PYRL_LIAB_BEGIN_DATE", TimestampType(), True), - StructField("QTY", DecimalType(38, 9), True), - StructField("RATE", DecimalType(38, 9), True), - StructField("AMOUNT", DecimalType(38, 9), True), - StructField("SPERCENT", DecimalType(38, 9), True), - StructField("DOC_XREF", StringType(), True), - StructField("IS_A", StringType(), True), - StructField("IS_S", StringType(), True), - StructField("IS_CP", StringType(), True), - StructField("IS_VP", StringType(), True), - StructField("IS_B", StringType(), True), - StructField("IS_EX", StringType(), True), - StructField("IS_I", StringType(), True), - StructField("IS_CL", StringType(), True), - StructField("IS_DPD", StringType(), True), - StructField("IS_DPD2", StringType(), True), - StructField("DPD_ID", DecimalType(38, 0), True), - StructField("IS_NP", StringType(), True), - StructField("TAXABLE_TYPE", DecimalType(38, 0), True), - StructField("IS_ARP", StringType(), True), - StructField("IS_APP", StringType(), True), - StructField("BALANCE1", DecimalType(38, 9), True), - StructField("BALANCE2", DecimalType(38, 9), True), - StructField("IS_FLAG1", StringType(), True), - StructField("IS_FLAG2", StringType(), True), - StructField("STATEMENT_ID", DecimalType(38, 0), True), - StructField("INVOICE_ID", DecimalType(38, 0), True), - StructField("STATEMENT_DATE", TimestampType(), True), - StructField("INVOICE_DATE", TimestampType(), True), - StructField("DUE_DATE", TimestampType(), True), - StructField("EXAMPLE1_ID", DecimalType(38, 0), True), - StructField("EXAMPLE2_ID", DecimalType(38, 0), True), - StructField("IS_FLAG3", StringType(), True), - StructField("ANOTHER_ID", DecimalType(38, 0), True), - StructField("MARKUP", DecimalType(38, 9), True), - StructField("S_DATE", TimestampType(), True), - StructField("SD_TYPE", DecimalType(38, 0), True), - StructField("SOURCE_TXN_ID", DecimalType(38, 0), True), - StructField("SOURCE_TXN_SEQUENCE", DecimalType(38, 0), True), - StructField("PAID_DATE", TimestampType(), True), - StructField("OFX_TXN_ID", DecimalType(38, 0), True), - StructField("OFX_MATCH_FLAG", DecimalType(38, 0), True), - StructField("OLB_MATCH_MODE", DecimalType(38, 0), True), - StructField("OLB_MATCH_AMOUNT", DecimalType(38, 9), True), - StructField("OLB_RULE_ID", DecimalType(38, 0), True), - StructField("ETMMODE", DecimalType(38, 0), True), - StructField("DDA_ID", DecimalType(38, 0), True), - StructField("DDL_STATUS", DecimalType(38, 0), True), - StructField("ICFS", DecimalType(38, 0), True), - StructField("CREATE_DATE", TimestampType(), True), - StructField("CREATE_USER_ID", DecimalType(38, 0), True), - StructField("LAST_MODIFY_DATE", TimestampType(), True), - 
StructField("LAST_MODIFY_USER_ID", DecimalType(38, 0), True), - StructField("EDIT_SEQUENCE", DecimalType(38, 0), True), - StructField("ADDED_AUDIT_ID", DecimalType(38, 0), True), - StructField("AUDIT_ID", DecimalType(38, 0), True), - StructField("AUDIT_FLAG", StringType(), True), - StructField("EXCEPTION_FLAG", StringType(), True), - StructField("IS_PENALTY", StringType(), True), - StructField("IS_INTEREST", StringType(), True), - StructField("NET_AMOUNT", DecimalType(38, 9), True), - StructField("TAX_AMOUNT", DecimalType(38, 9), True), - StructField("TAX_CODE_ID", DecimalType(38, 0), True), - StructField("TAX_RATE_ID", DecimalType(38, 0), True), - StructField("CURRENCY_TYPE", DecimalType(38, 0), True), - StructField("EXCHANGE_RATE", DecimalType(38, 9), True), - StructField("HOME_AMOUNT", DecimalType(38, 9), True), - StructField("HOME_OPEN_BALANCE", DecimalType(38, 9), True), - StructField("IS_FOREX_GAIN_LOSS", StringType(), True), - StructField("SPECIAL_TAX_TYPE", DecimalType(38, 0), True), - StructField("SPECIAL_TAX_OPEN_BALANCE", DecimalType(38, 9), True), - StructField("TAX_OVERRIDE_DELTA_AMOUNT", DecimalType(38, 9), True), - StructField("INCLUSIVE_AMOUNT", DecimalType(38, 9), True), - StructField("CUSTOM_ACCOUNT_TAX_AMT", DecimalType(38, 9), True), - StructField("J_CODE_ID", DecimalType(38, 0), True), - StructField("DISCOUNT_ID", DecimalType(38, 0), True), - StructField("DISCOUNT_AMOUNT", DecimalType(38, 9), True), - StructField("TXN_DISCOUNT_AMOUNT", DecimalType(38, 9), True), - StructField("SUBTOTAL_AMOUNT", DecimalType(38, 9), True), - StructField("LINE_DETAIL_TYPE", DecimalType(38, 0), True), - StructField("W_RATE_ID", DecimalType(38, 0), True), - StructField("R_QUANTITY", DecimalType(38, 9), True), - StructField("R_AMOUNT", DecimalType(38, 9), True), - StructField("SRC_QTY_USED", DecimalType(38, 9), True), - StructField("SRC_AMT_USED", DecimalType(38, 9), True), - StructField("LM_CLOSED", StringType(), True), - StructField("CUSTOM_FIELD_VALUES", StringType(), True), - StructField("PROGRESS_TRACKING_TYPE", DecimalType(38, 0), True), - StructField("ITEM_RATE_TYPE", DecimalType(38, 0), True), - StructField("CUSTOM_FIELD_VALS", StringType(), True), - StructField("REGION_C_CODE", StringType(), True), - StructField("LAST_MODIFIED_UTC", TimestampType(), True), - StructField("date", DateType(), True), - StructField("yearMonth", StringType(), True), - StructField("isDeleted", BooleanType(), True) -]) +schema = StructType( + [ + StructField("PK1", StringType(), True), + StructField("XYYZ_IDS", StringType(), True), + StructField("R_ID", IntegerType(), True), + StructField("CL_ID", StringType(), True), + StructField("INGEST_DATE", TimestampType(), True), + StructField("CMPY_ID", DecimalType(38, 0), True), + StructField("TXN_ID", DecimalType(38, 0), True), + StructField("SEQUENCE_NUMBER", DecimalType(38, 0), True), + StructField("DETAIL_ORDER", DecimalType(38, 0), True), + StructField("TX_T_ID", DecimalType(38, 0), True), + StructField("TXN_DATE", TimestampType(), True), + StructField("AN_ID", DecimalType(38, 0), True), + StructField("ANC_ID", DecimalType(38, 0), True), + StructField("ANV_ID", DecimalType(38, 0), True), + StructField("ANE_ID", DecimalType(38, 0), True), + StructField("AND_ID", DecimalType(38, 0), True), + StructField("APM_ID", DecimalType(38, 0), True), + StructField("ACL_ID", DecimalType(38, 0), True), + StructField("MEMO_TEXT", StringType(), True), + StructField("ITEM_ID", DecimalType(38, 0), True), + StructField("ITEM2_ID", DecimalType(38, 0), True), + 
StructField("V1_BASE", DecimalType(38, 9), True), + StructField("V1_YTD_AMT", DecimalType(38, 9), True), + StructField("V1_YTD_HOURS", DecimalType(38, 0), True), + StructField("ISTT", DecimalType(38, 9), True), + StructField("XXX_AMT", StringType(), True), + StructField("XXX_BASE", StringType(), True), + StructField("XXX_ISTT", StringType(), True), + StructField("HOURS", DecimalType(38, 0), True), + StructField("STATE", DecimalType(38, 0), True), + StructField("LSTATE", DecimalType(38, 0), True), + StructField("XXX_JURISDICTION_ID", DecimalType(38, 0), True), + StructField("XXY_JURISDICTION_ID", DecimalType(38, 0), True), + StructField("AS_OF_DATE", TimestampType(), True), + StructField("IS_PAYOUT", StringType(), True), + StructField("IS_PYRL_LIABILITY", StringType(), True), + StructField("IS_PYRL_SUMMARY", StringType(), True), + StructField("PYRL_LIABILITY_DATE", TimestampType(), True), + StructField("PYRL_LIAB_BEGIN_DATE", TimestampType(), True), + StructField("QTY", DecimalType(38, 9), True), + StructField("RATE", DecimalType(38, 9), True), + StructField("AMOUNT", DecimalType(38, 9), True), + StructField("SPERCENT", DecimalType(38, 9), True), + StructField("DOC_XREF", StringType(), True), + StructField("IS_A", StringType(), True), + StructField("IS_S", StringType(), True), + StructField("IS_CP", StringType(), True), + StructField("IS_VP", StringType(), True), + StructField("IS_B", StringType(), True), + StructField("IS_EX", StringType(), True), + StructField("IS_I", StringType(), True), + StructField("IS_CL", StringType(), True), + StructField("IS_DPD", StringType(), True), + StructField("IS_DPD2", StringType(), True), + StructField("DPD_ID", DecimalType(38, 0), True), + StructField("IS_NP", StringType(), True), + StructField("TAXABLE_TYPE", DecimalType(38, 0), True), + StructField("IS_ARP", StringType(), True), + StructField("IS_APP", StringType(), True), + StructField("BALANCE1", DecimalType(38, 9), True), + StructField("BALANCE2", DecimalType(38, 9), True), + StructField("IS_FLAG1", StringType(), True), + StructField("IS_FLAG2", StringType(), True), + StructField("STATEMENT_ID", DecimalType(38, 0), True), + StructField("INVOICE_ID", DecimalType(38, 0), True), + StructField("STATEMENT_DATE", TimestampType(), True), + StructField("INVOICE_DATE", TimestampType(), True), + StructField("DUE_DATE", TimestampType(), True), + StructField("EXAMPLE1_ID", DecimalType(38, 0), True), + StructField("EXAMPLE2_ID", DecimalType(38, 0), True), + StructField("IS_FLAG3", StringType(), True), + StructField("ANOTHER_ID", DecimalType(38, 0), True), + StructField("MARKUP", DecimalType(38, 9), True), + StructField("S_DATE", TimestampType(), True), + StructField("SD_TYPE", DecimalType(38, 0), True), + StructField("SOURCE_TXN_ID", DecimalType(38, 0), True), + StructField("SOURCE_TXN_SEQUENCE", DecimalType(38, 0), True), + StructField("PAID_DATE", TimestampType(), True), + StructField("OFX_TXN_ID", DecimalType(38, 0), True), + StructField("OFX_MATCH_FLAG", DecimalType(38, 0), True), + StructField("OLB_MATCH_MODE", DecimalType(38, 0), True), + StructField("OLB_MATCH_AMOUNT", DecimalType(38, 9), True), + StructField("OLB_RULE_ID", DecimalType(38, 0), True), + StructField("ETMMODE", DecimalType(38, 0), True), + StructField("DDA_ID", DecimalType(38, 0), True), + StructField("DDL_STATUS", DecimalType(38, 0), True), + StructField("ICFS", DecimalType(38, 0), True), + StructField("CREATE_DATE", TimestampType(), True), + StructField("CREATE_USER_ID", DecimalType(38, 0), True), + StructField("LAST_MODIFY_DATE", 
TimestampType(), True), + StructField("LAST_MODIFY_USER_ID", DecimalType(38, 0), True), + StructField("EDIT_SEQUENCE", DecimalType(38, 0), True), + StructField("ADDED_AUDIT_ID", DecimalType(38, 0), True), + StructField("AUDIT_ID", DecimalType(38, 0), True), + StructField("AUDIT_FLAG", StringType(), True), + StructField("EXCEPTION_FLAG", StringType(), True), + StructField("IS_PENALTY", StringType(), True), + StructField("IS_INTEREST", StringType(), True), + StructField("NET_AMOUNT", DecimalType(38, 9), True), + StructField("TAX_AMOUNT", DecimalType(38, 9), True), + StructField("TAX_CODE_ID", DecimalType(38, 0), True), + StructField("TAX_RATE_ID", DecimalType(38, 0), True), + StructField("CURRENCY_TYPE", DecimalType(38, 0), True), + StructField("EXCHANGE_RATE", DecimalType(38, 9), True), + StructField("HOME_AMOUNT", DecimalType(38, 9), True), + StructField("HOME_OPEN_BALANCE", DecimalType(38, 9), True), + StructField("IS_FOREX_GAIN_LOSS", StringType(), True), + StructField("SPECIAL_TAX_TYPE", DecimalType(38, 0), True), + StructField("SPECIAL_TAX_OPEN_BALANCE", DecimalType(38, 9), True), + StructField("TAX_OVERRIDE_DELTA_AMOUNT", DecimalType(38, 9), True), + StructField("INCLUSIVE_AMOUNT", DecimalType(38, 9), True), + StructField("CUSTOM_ACCOUNT_TAX_AMT", DecimalType(38, 9), True), + StructField("J_CODE_ID", DecimalType(38, 0), True), + StructField("DISCOUNT_ID", DecimalType(38, 0), True), + StructField("DISCOUNT_AMOUNT", DecimalType(38, 9), True), + StructField("TXN_DISCOUNT_AMOUNT", DecimalType(38, 9), True), + StructField("SUBTOTAL_AMOUNT", DecimalType(38, 9), True), + StructField("LINE_DETAIL_TYPE", DecimalType(38, 0), True), + StructField("W_RATE_ID", DecimalType(38, 0), True), + StructField("R_QUANTITY", DecimalType(38, 9), True), + StructField("R_AMOUNT", DecimalType(38, 9), True), + StructField("SRC_QTY_USED", DecimalType(38, 9), True), + StructField("SRC_AMT_USED", DecimalType(38, 9), True), + StructField("LM_CLOSED", StringType(), True), + StructField("CUSTOM_FIELD_VALUES", StringType(), True), + StructField("PROGRESS_TRACKING_TYPE", DecimalType(38, 0), True), + StructField("ITEM_RATE_TYPE", DecimalType(38, 0), True), + StructField("CUSTOM_FIELD_VALS", StringType(), True), + StructField("REGION_C_CODE", StringType(), True), + StructField("LAST_MODIFIED_UTC", TimestampType(), True), + StructField("date", DateType(), True), + StructField("yearMonth", StringType(), True), + StructField("isDeleted", BooleanType(), True), + ] +) print("schema", schema) @@ -158,11 +160,11 @@ def setUp(self): @classmethod def setUpClass(cls): - cls.testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=cls.row_count, - partitions=4) - .withSchema(schema) - .withIdOutput() - ) + cls.testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=cls.row_count, partitions=4) + .withSchema(schema) + .withIdOutput() + ) print("Test generation plan") print("=============================") @@ -206,22 +208,24 @@ def test_large_clone(self): sale_values = ['RETAIL', 'ONLINE', 'WHOLESALE', 'RETURN'] sale_weights = [1, 5, 5, 1] - ds = (self.testDataSpec.clone().withRowCount(1000) - # test legacy argument `match_types` - .withColumnSpecs(patterns=".*_ID", match_types=StringType(), format="%010d", - minValue=10, maxValue=123, step=1) - # test revised argument `matchTypes` - .withColumnSpecs(patterns=".*_IDS", matchTypes=StringType(), format="%010d", - minValue=1, maxValue=100, step=1) - .withColumnSpec("R_ID", minValue=1, maxValue=100, step=1) - .withColumnSpec("XYYZ_IDS", 
minValue=1, maxValue=123, step=1, - format="%05d") - # .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - # example of IS_SALE - .withColumnSpec("IS_S", values=sale_values, weights=sale_weights, random=True) - # .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - - ) + ds = ( + self.testDataSpec.clone() + .withRowCount(1000) + # test legacy argument `match_types` + .withColumnSpecs( + patterns=".*_ID", match_types=StringType(), format="%010d", minValue=10, maxValue=123, step=1 + ) + # test revised argument `matchTypes` + .withColumnSpecs( + patterns=".*_IDS", matchTypes=StringType(), format="%010d", minValue=1, maxValue=100, step=1 + ) + .withColumnSpec("R_ID", minValue=1, maxValue=100, step=1) + .withColumnSpec("XYYZ_IDS", minValue=1, maxValue=123, step=1, format="%05d") + # .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + # example of IS_SALE + .withColumnSpec("IS_S", values=sale_values, weights=sale_weights, random=True) + # .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + ) df = ds.build() ds.explain() @@ -245,6 +249,7 @@ def test_large_clone(self): id_values2 = [r[0] for r in df.select("XYYZ_IDS").distinct().collect()] self.assertSetEqual(set(id_values2), set(xyyz_values)) + # run the tests # if __name__ == '__main__': # print("Trying to run tests") diff --git a/tests/test_logging.py b/tests/test_logging.py index 535cd605..11859155 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -11,6 +11,7 @@ @pytest.fixture(scope="class") def setupSpark(): import dbldatagen as dg + sparkSession = dg.SparkSingleton.getLocalInstance("unit tests") return sparkSession @@ -31,7 +32,7 @@ class TestLoggingOperation: row_count = SMALL_ROW_COUNT def setup_log_capture(self, caplog_object): - """ set up log capture fixture + """set up log capture fixture Sets up log capture fixture to only capture messages after setup and only capture warnings and errors @@ -121,16 +122,16 @@ def test_logging_operation2(self, setupSpark, caplog): from dbldatagen import DataGenerator - spec = (DataGenerator(sparkSession=spark, name="test_data_set1", rows=10000, seedMethod='hash_fieldname') - .withIdOutput() - .withColumn("r", "float", expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=10) - .withColumn("code1", "int", min=100, max=200) - .withColumn("code2", "int", min=0, max=10) - .withColumn("code3", "string", values=['a', 'b', 'c']) - .withColumn("code4", "string", values=['a', 'b', 'c'], random=True) - .withColumn("code5", "string", values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + spec = ( + DataGenerator(sparkSession=spark, name="test_data_set1", rows=10000, seedMethod='hash_fieldname') + .withIdOutput() + .withColumn("r", "float", expr="floor(rand() * 350) * (86400 + 3600)", numColumns=10) + .withColumn("code1", "int", min=100, max=200) + .withColumn("code2", "int", min=0, max=10) + .withColumn("code3", "string", values=['a', 'b', 'c']) + .withColumn("code4", "string", values=['a', 'b', 'c'], random=True) + .withColumn("code5", "string", values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) df = spec.build() diff --git a/tests/test_multi_table.py b/tests/test_multi_table.py new file mode 100644 index 00000000..80d58c07 --- /dev/null +++ b/tests/test_multi_table.py @@ -0,0 +1,146 @@ +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from pyspark.sql.functions import col + +import dbldatagen as dg +from dbldatagen.multi_table_builder import MultiTableBuilder +from dbldatagen.relation import ForeignKeyRelation +from dbldatagen.utils import DataGenError + + +spark = dg.SparkSingleton.getLocalInstance("unit tests") + + +def _build_generator(name: str, column_names: list[str], rows: int = 10) -> dg.DataGenerator: + """ + Helper to create a ``DataGenerator`` with deterministic integer columns. + """ + generator = dg.DataGenerator(sparkSession=spark, name=name, rows=rows, partitions=1) + + for index, column_name in enumerate(column_names): + min_value = index * 100 + max_value = min_value + rows - 1 + generator = generator.withColumn(column_name, "int", minValue=min_value, maxValue=max_value) + + return generator + + +class TestMultiTableBuilder: + def test_single_data_generator_builds_dataset(self) -> None: + builder = MultiTableBuilder() + generator = _build_generator("single_gen", ["order_id", "order_value"], rows=12) + + builder.add_data_generator( + name="orders", + generator=generator, + columns=[col("order_id").alias("id"), "order_value"], + ) + + results = builder.build() + + assert set(results) == {"orders"} + orders_df = results["orders"] + assert orders_df.count() == 12 + assert orders_df.columns == ["id", "order_value"] + + def test_independent_data_generators_build_individually(self) -> None: + builder = MultiTableBuilder() + + generator_a = _build_generator("generator_a", ["a_id", "a_value"], rows=5) + generator_b = _build_generator("generator_b", ["b_id", "b_value"], rows=7) + + builder.add_data_generator("table_a", generator_a, columns=["a_id", "a_value"]) + builder.add_data_generator("table_b", generator_b, columns=["b_id", "b_value"]) + + results = builder.build() + + assert set(results.keys()) == {"table_a", "table_b"} + assert results["table_a"].count() == 5 + assert results["table_b"].count() == 7 + assert len(builder.data_generators) == 2 + + def test_foreign_key_relation_requires_shared_generator(self) -> None: + builder = MultiTableBuilder() + + parent_generator = _build_generator("parent_gen", ["parent_id", "parent_value"], rows=6) + child_generator = _build_generator("child_gen", ["child_id", "child_parent_id"], rows=6) + + builder.add_data_generator("parents", parent_generator, columns=["parent_id", "parent_value"]) + builder.add_data_generator( + "children", + child_generator, + columns=["child_id", "child_parent_id"], + ) + + builder.add_foreign_key_relation( + ForeignKeyRelation( + from_table="children", + from_column="child_parent_id", + to_table="parents", + to_column="parent_id", + ) + ) + + with pytest.raises(DataGenError): + builder.build(["children"]) + + def test_partial_relation_with_mismatched_generators_raises(self) -> None: + builder = MultiTableBuilder() + + orders_generator = _build_generator("orders_gen", ["order_id", "order_value"], rows=6) + line_items_generator = _build_generator( + "line_items_gen", + ["line_item_id", "order_id"], + rows=8, + ) + shipments_generator = _build_generator("shipments_gen", ["shipment_id"], rows=4) + + builder.add_data_generator("orders", orders_generator, columns=["order_id", "order_value"]) + builder.add_data_generator("line_items", line_items_generator, columns=["line_item_id", "order_id"]) + builder.add_data_generator("shipments", shipments_generator, columns=["shipment_id"]) + + builder.add_foreign_key_relation( + ForeignKeyRelation( + from_table="line_items", + from_column="order_id", + to_table="orders", + 
to_column="order_id", + ) + ) + + with pytest.raises(DataGenError): + builder.build() + + def test_transitive_relation_requires_shared_generator(self) -> None: + builder = MultiTableBuilder() + + generator_a = _build_generator("gen_a", ["a_id", "b_id"], rows=5) + generator_b = _build_generator("gen_b", ["b_id", "c_id"], rows=5) + generator_c = _build_generator("gen_c", ["c_id"], rows=5) + + builder.add_data_generator("table_a", generator_a, columns=["a_id", "b_id"]) + builder.add_data_generator("table_b", generator_b, columns=["b_id", "c_id"]) + builder.add_data_generator("table_c", generator_c, columns=["c_id"]) + + builder.add_foreign_key_relation( + ForeignKeyRelation( + from_table="table_a", + from_column="b_id", + to_table="table_b", + to_column="b_id", + ) + ) + builder.add_foreign_key_relation( + ForeignKeyRelation( + from_table="table_b", + from_column="c_id", + to_table="table_c", + to_column="c_id", + ) + ) + + with pytest.raises(DataGenError): + builder.get_dataset("table_a") diff --git a/tests/test_options.py b/tests/test_options.py index 78a595f8..2035678a 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -40,7 +40,6 @@ def test_basic(self): .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) # base column specifies dependent column - .withColumn("site_cd", "string", prefix='site', baseColumn='code1') .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) @@ -94,17 +93,14 @@ def test_aliased_options(self): .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, distribution="normal") .withColumn("code3", "integer", min=1, max=20, step=1, base_column="code1") .withColumn("code4", "integer", min=1, max=20, step=1, baseColumn="code1") - # implicit allows column definition to be overridden - used by system when initializing from schema .withColumn("code5", "integer", min=1, max=20, step=1, baseColumn="code1", implicit=True) - .withColumn("code5", "integer", min=1, max=20, step=1, baseColumn="code4", random_seed=45) .withColumn("code6", "integer", minValue=1, maxValue=20, step=1, omit=True) .withColumn("code7", "integer", min=1, max=20, step=1, baseColumn="code6") .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, distribution="normal") .withColumn("site_cd1", "string", prefix='site', baseColumn='code1', text_separator="") .withColumn("site_cd2", "string", prefix='site', baseColumn='code1', textSeparator="-") - ) colSpec1 = testdata_generator.getColumnSpec("site_cd1") @@ -144,11 +140,12 @@ def test_aliased_options2(self): testdata_generator = ( dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=10000, partitions=4) .withColumn("code1", "integer", min=1, max=20, step=1) - .withColumn("site_cd1", "string", prefix='site', baseColumn='code1', - random_seed_method=dg.RANDOM_SEED_FIXED) - .withColumn("site_cd2", "string", prefix='site', baseColumn='code1', - randomSeedMethod=dg.RANDOM_SEED_HASH_FIELD_NAME) - + .withColumn( + "site_cd1", "string", prefix='site', baseColumn='code1', random_seed_method=dg.RANDOM_SEED_FIXED + ) + .withColumn( + "site_cd2", "string", prefix='site', baseColumn='code1', randomSeedMethod=dg.RANDOM_SEED_HASH_FIELD_NAME + ) ) colSpec1 = testdata_generator.getColumnSpec("site_cd1") @@ -177,7 +174,6 @@ def test_random1(self): .withColumn("code3", "integer", minValue=1, maxValue=20, 
step=1, random=True) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) # base column specifies dependent column - .withColumn("site_cd", "string", prefix='site', baseColumn='code1') .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) @@ -200,7 +196,6 @@ def test_random2(self): .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1) # base column specifies dependent column - .withColumn("site_cd", "string", values=["one", "two", "three"], weights=[1, 2, 3], baseColumn='code1') .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, prefix='status') .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"]) @@ -245,7 +240,6 @@ def test_random_multiple_columns(self): .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, numColumns=5) .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, numFeatures=(3, 5), structType="array") .withColumn("code4", "integer", minValue=1, maxValue=20, step=1) - ) df = ds.build() @@ -255,13 +249,7 @@ def test_random_multiple_columns(self): df.show() - @pytest.mark.parametrize("numFeaturesSupplied", - [(3, 5, 3), - (3.4, 3), - ("3", "5"), - "3", - (5, 3) - ]) + @pytest.mark.parametrize("numFeaturesSupplied", [(3, 5, 3), (3.4, 3), ("3", "5"), "3", (5, 3)]) def test_random_multiple_columns_bad(self, numFeaturesSupplied): # will have implied column `id` for ordinal of row @@ -270,10 +258,16 @@ def test_random_multiple_columns_bad(self, numFeaturesSupplied): dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=1000, partitions=4, random=True) .withColumn("code1", "integer", min=1, max=20, step=1) .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, numColumns=5) - .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, numFeatures=numFeaturesSupplied, - structType="array") + .withColumn( + "code3", + "integer", + minValue=1, + maxValue=20, + step=1, + numFeatures=numFeaturesSupplied, + structType="array", + ) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1) - ) df = ds.build() @@ -289,7 +283,6 @@ def test_random_multiple_columns_warning(self, errorsAndWarningsLog): .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, numColumns=5) .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, numColumns=(3, 5)) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, numFeatures=(3, 5)) - ) df = ds.build() @@ -299,20 +292,21 @@ def test_random_multiple_columns_warning(self, errorsAndWarningsLog): assert msgs > 0 - @pytest.mark.parametrize("numFeaturesSupplied", - [3, - (2, 4), - 0, - (0, 3) - ]) + @pytest.mark.parametrize("numFeaturesSupplied", [3, (2, 4), 0, (0, 3)]) def test_multiple_columns_email(self, numFeaturesSupplied): # will have implied column `id` for ordinal of row ds = ( dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=1000, partitions=4, random=True) .withColumn("name", "string", percentNulls=0.01, template=r'\\w \\w|\\w A. 
\\w|test') - .withColumn("emails", "string", template=r'\\w.\\w@\\w.com', random=True, - numFeatures=numFeaturesSupplied, structType="array") + .withColumn( + "emails", + "string", + template=r'\\w.\\w@\\w.com', + random=True, + numFeatures=numFeaturesSupplied, + structType="array", + ) ) df = ds.build() @@ -330,20 +324,21 @@ def test_multiple_columns_email(self, numFeaturesSupplied): assert min(set_lengths) == min_lengths assert max(set_lengths) == max_lengths - @pytest.mark.parametrize("numFeaturesSupplied", - [3, - (2, 4), - 0, - (0, 3) - ]) + @pytest.mark.parametrize("numFeaturesSupplied", [3, (2, 4), 0, (0, 3)]) def test_multi_email_random(self, numFeaturesSupplied): # will have implied column `id` for ordinal of row ds = ( dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=1000, partitions=4, random=True) .withColumn("name", "string", percentNulls=0.01, template=r'\\w \\w|\\w A. \\w|test') - .withColumn("emails", "string", template=r'\\w.\\w@\\w.com', random=True, - numFeatures=numFeaturesSupplied, structType="array") + .withColumn( + "emails", + "string", + template=r'\\w.\\w@\\w.com', + random=True, + numFeatures=numFeaturesSupplied, + structType="array", + ) ) df = ds.build() diff --git a/tests/test_output.py b/tests/test_output.py index 3d41c50b..fc6c2ddd 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -44,12 +44,11 @@ def test_build_output_data_batch(self, get_output_directories, seed_column_name, rows=100, partitions=4, seedMethod='hash_fieldname', - seedColumnName=seed_column_name + seedColumnName=seed_column_name, ) gen = ( - gen - .withIdOutput() + gen.withIdOutput() .withColumn("code1", IntegerType(), minValue=100, maxValue=200) .withColumn("code2", IntegerType(), minValue=0, maxValue=10) .withColumn("code3", StringType(), values=['a', 'b', 'c']) @@ -79,12 +78,11 @@ def test_build_output_data_streaming(self, get_output_directories, seed_column_n rows=100, partitions=4, seedMethod='hash_fieldname', - seedColumnName=seed_column_name + seedColumnName=seed_column_name, ) gen = ( - gen - .withIdOutput() + gen.withIdOutput() .withColumn("code1", IntegerType(), minValue=100, maxValue=200) .withColumn("code2", IntegerType(), minValue=0, maxValue=10) .withColumn("code3", StringType(), values=['a', 'b', 'c']) @@ -97,7 +95,7 @@ def test_build_output_data_streaming(self, get_output_directories, seed_column_n output_mode="append", format=table_format, options={"mergeSchema": "true", "checkpointLocation": f"{data_dir}/{checkpoint_dir}"}, - trigger={"processingTime": "1 SECOND"} + trigger={"processingTime": "1 SECOND"}, ) query = gen.saveAsDataset(output_dataset, with_streaming=True) diff --git a/tests/test_pandas_integration.py b/tests/test_pandas_integration.py index 413a8360..f811df5e 100644 --- a/tests/test_pandas_integration.py +++ b/tests/test_pandas_integration.py @@ -7,30 +7,34 @@ from pyspark.sql.types import BooleanType, DateType from pyspark.sql.types import StructType, StructField, StringType, TimestampType -schema = StructType([ - StructField("PK1", StringType(), True), - StructField("LAST_MODIFIED_UTC", TimestampType(), True), - StructField("date", DateType(), True), - StructField("str1", StringType(), True), - StructField("email", StringType(), True), - StructField("ip_addr", StringType(), True), - StructField("phone", StringType(), True), - StructField("isDeleted", BooleanType(), True) -]) +schema = StructType( + [ + StructField("PK1", StringType(), True), + StructField("LAST_MODIFIED_UTC", TimestampType(), True), + StructField("date", 
DateType(), True), + StructField("str1", StringType(), True), + StructField("email", StringType(), True), + StructField("ip_addr", StringType(), True), + StructField("phone", StringType(), True), + StructField("isDeleted", BooleanType(), True), + ] +) print("schema", schema) -spark = SparkSession.builder \ - .master("local[4]") \ - .appName("spark unit tests") \ - .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") \ - .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") \ +spark = ( + SparkSession.builder.master("local[4]") + .appName("spark unit tests") + .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") + .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") .getOrCreate() +) # Test manipulation and generation of test data for a large schema class TestPandasIntegration(unittest.TestCase): - """ Test that build environment is setup correctly for pandas and numpy integration""" + """Test that build environment is setup correctly for pandas and numpy integration""" + testDataSpec = None dfTestData = None row_count = 100000 @@ -75,11 +79,12 @@ def test_pandas(self): @unittest.skip("not yet debugged") def test_pandas_udf(self): utest_pandas = pandas_udf(self.pandas_udf_example, returnType=StringType()).asNondeterministic() - df = (spark.range(1000000) - .withColumn("x", expr("cast(id as string)")) - .withColumn("y", expr("cast(id as string)")) - .withColumn("z", utest_pandas(col("x"), col("y"))) - ) + df = ( + spark.range(1000000) + .withColumn("x", expr("cast(id as string)")) + .withColumn("y", expr("cast(id as string)")) + .withColumn("z", utest_pandas(col("x"), col("y"))) + ) df.show() @@ -98,4 +103,5 @@ def test_numpy2(self): self.assertGreater(np.sum(data), 0) -# \ No newline at end of file + +# diff --git a/tests/test_quick_tests.py b/tests/test_quick_tests.py index 4d1eab96..1a04d4bb 100644 --- a/tests/test_quick_tests.py +++ b/tests/test_quick_tests.py @@ -2,8 +2,17 @@ import pytest from pyspark.sql.types import ( - StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType, - ShortType, LongType + StructType, + StructField, + IntegerType, + StringType, + FloatType, + DateType, + DecimalType, + DoubleType, + ByteType, + ShortType, + LongType, ) @@ -11,16 +20,17 @@ from dbldatagen import DataGenerator from dbldatagen import NRange, DateRange -schema = StructType([ - StructField("site_id", IntegerType(), True), - StructField("site_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True), - StructField("state1", StringType(), True), - StructField("state2", StringType(), True), - StructField("sector_technology_desc", StringType(), True), - -]) +schema = StructType( + [ + StructField("site_id", IntegerType(), True), + StructField("site_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("state1", StringType(), True), + StructField("state2", StringType(), True), + StructField("sector_technology_desc", StringType(), True), + ] +) interval = timedelta(seconds=10) start = datetime(2018, 10, 1, 6, 0, 0) @@ -30,13 +40,14 @@ src_start = datetime(2017, 10, 1, 0, 0, 0) src_end = datetime(2018, 10, 1, 6, 0, 0) -schema = StructType([ - StructField("site_id", IntegerType(), True), - StructField("site_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True) - -]) +schema = StructType( + [ + StructField("site_id", IntegerType(), True), + 
StructField("site_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + ] +) # build spark session spark = dg.SparkSingleton.getLocalInstance("quick tests") @@ -49,10 +60,9 @@ class TestQuickTests: """ def test_analyzer(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4).withIdOutput().build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -69,18 +79,17 @@ def test_analyzer(self): print("Summary;", results) def test_complex_datagen(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, - partitions=4) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") - .withColumn("code1a", IntegerType(), unique_values=100) - .withColumn("code1b", IntegerType(), min=1, max=200) - .withColumn("code2", IntegerType(), maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)") + .withColumn("code1a", IntegerType(), unique_values=100) + .withColumn("code1b", IntegerType(), min=1, max=200) + .withColumn("code2", IntegerType(), maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) testDataDF2 = testDataSpec.build() @@ -95,13 +104,15 @@ def test_complex_datagen(self): # testDataDF2.show() testDataDF2.createOrReplaceTempView("testdata") - df_stats = spark.sql("""select min(code1a) as min1a, + df_stats = spark.sql( + """select min(code1a) as min1a, max(code1a) as max1a, min(code1b) as min1b, max(code1b) as max1b, min(code2) as min2, max(code2) as max2 - from testdata""") + from testdata""" + ) stats = df_stats.collect()[0] print("stats", stats) @@ -120,24 +131,28 @@ def test_generate_name(self): assert n1 != n2, "Names should be different" def test_column_specifications(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_column_specifications") expectedColumns = set((["id", "site_id", "site_cd", "c", "c1", "sector_status_desc", "s"])) assert expectedColumns == set(([x.name for x in tgen._allColumnSpecs])) def test_inferred_columns(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) 
- .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_inferred_columns") expectedColumns = set((["id", "site_id", "site_cd", "c", "c1", "sector_status_desc", "s"])) @@ -145,12 +160,14 @@ def test_inferred_columns(self): assert expectedColumns == set((tgen.getInferredColumnNames())) def test_output_columns(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_output_columns") expectedColumns = set((["site_id", "site_cd", "c", "c1", "sector_status_desc"])) @@ -158,12 +175,14 @@ def test_output_columns(self): assert expectedColumns == set((tgen.getOutputColumnNames())) def test_with_column_spec_for_missing_column(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_with_column_spec_for_missing_column") with pytest.raises(Exception): @@ -171,12 +190,14 @@ def test_with_column_spec_for_missing_column(self): assert t2 is not None, "expecting t2 to be a new generator spec" def test_with_column_spec_for_duplicate_column(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_with_column_spec_for_duplicate_column") with 
pytest.raises(Exception): @@ -185,24 +206,28 @@ def test_with_column_spec_for_duplicate_column(self): assert t3 is not None, "expecting t3 to be a new generator spec" def test_with_column_spec_for_duplicate_column2(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_with_column_spec_for_duplicate_column2") t2 = tgen.withColumn("site_id", "string", minValue=1, maxValue=200, step=1, random=True) assert t2 is not None, "expecting t2 to be a new generator spec" def test_with_column_spec_for_id_column(self): - tgen = (DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) - .withSchema(schema) - .withColumn("sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', - random=True, omit=True)) + tgen = ( + DataGenerator(sparkSession=spark, name="test_data_set", rows=1000000, partitions=8) + .withSchema(schema) + .withColumn( + "sector_status_desc", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True + ) + .withColumn("s", StringType(), minValue=1, maxValue=200, step=1, prefix='status', random=True, omit=True) + ) print("test_with_column_spec_for_id_column") t2 = tgen.withIdOutput() @@ -212,40 +237,40 @@ def test_with_column_spec_for_id_column(self): assert expectedColumns == set((t2.getOutputColumnNames())) def test_basic_ranges_with_view(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("code1a", IntegerType(), unique_values=100) - .withColumn("code1b", IntegerType(), minValue=1, maxValue=100) - .withColumn("code1c", IntegerType(), minValue=1, maxValue=200, unique_values=100) - .withColumn("code1d", IntegerType(), minValue=1, maxValue=200, step=3, unique_values=50) - .withColumn("code2", IntegerType(), maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, partitions=4) + .withIdOutput() + .withColumn("code1a", IntegerType(), unique_values=100) + .withColumn("code1b", IntegerType(), minValue=1, maxValue=100) + .withColumn("code1c", IntegerType(), minValue=1, maxValue=200, unique_values=100) + .withColumn("code1d", IntegerType(), minValue=1, maxValue=200, step=3, unique_values=50) + .withColumn("code2", IntegerType(), maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) 
testDataSpec.build(withTempView=True).cache() # we refer to the view generated above - result = spark.sql("""select count(distinct code1a), + result = spark.sql( + """select count(distinct code1a), count(distinct code1b), count(distinct code1c) - from ranged_data""").collect()[0] + from ranged_data""" + ).collect()[0] assert 100 == result[0] assert 100 == result[1] assert 100 == result[2] def test_basic_formatting1(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="values") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="values") + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -254,14 +279,13 @@ def test_basic_formatting1(self): assert rowCount == 100000 def test_basic_formatting2(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -270,13 +294,13 @@ def test_basic_formatting2(self): assert rowCount == 100000 def test_basic_formatting_discrete_values(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -285,16 +309,15 @@ def test_basic_formatting_discrete_values(self): assert rowCount == 100000 def test_basic_formatting3(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - - .withColumn("str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", 
rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn( + "str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["one", "two", "three"] + ) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -303,16 +326,14 @@ def test_basic_formatting3(self): assert rowCount == 100000 def test_basic_formatting3a(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - - # in this case values from base column are passed as array - .withColumn("str5b", StringType(), format="test %s", baseColumn=["val1", "val2"]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + # in this case values from base column are passed as array + .withColumn("str5b", StringType(), format="test %s", baseColumn=["val1", "val2"]) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -322,17 +343,16 @@ def test_basic_formatting3a(self): @pytest.mark.skip(reason="not yet implemented for multiple base columns") def test_basic_formatting4(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - - # when specifying multiple base columns - .withColumn("str5b", StringType(), format="test %s %s", baseColumn=["val1", "val2"], - base_column_type="values") - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + # when specifying multiple base columns + .withColumn( + "str5b", StringType(), format="test %s %s", baseColumn=["val1", "val2"], base_column_type="values" + ) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -341,23 +361,40 @@ def test_basic_formatting4(self): assert rowCount == 100000 def test_basic_formatting5(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - - .withColumn("str1", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"]) - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"], weights=[3, 1, 1]) - - .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"], template=r"test \v0") - .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"], weights=[3, 1, 1], template=r"test \v0") - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + 
.withColumn( + "str1", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["one", "two", "three"] + ) + .withColumn( + "str2", + StringType(), + format="test %s", + baseColumn=["val1", "val2"], + values=["one", "two", "three"], + weights=[3, 1, 1], + ) + .withColumn( + "str3", + StringType(), + format="test %s", + baseColumn=["val1", "val2"], + values=["one", "two", "three"], + template=r"test \v0", + ) + .withColumn( + "str4", + StringType(), + format="test %s", + baseColumn=["val1", "val2"], + values=["one", "two", "three"], + weights=[3, 1, 1], + template=r"test \v0", + ) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -366,25 +403,23 @@ def test_basic_formatting5(self): assert rowCount == 100000 def test_basic_formatting(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str1", StringType(), format="test %d") - # .withColumn("str1a", StringType(), format="test %s") - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="values") - .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"]) - .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str1", StringType(), format="test %d") + # .withColumn("str1a", StringType(), format="test %s") + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="values") + .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn( + "str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["one", "two", "three"] + ) + .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -393,13 +428,13 @@ def test_basic_formatting(self): assert rowCount == 100000 def test_basic_prefix(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=1000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("val3", StringType(), values=["one", "two", "three"]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=1000, partitions=4) + .withIdOutput() + .withColumn("val1", 
IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("val3", StringType(), values=["one", "two", "three"]) + ) formattedDF = testDataSpec.build(withTempView=True) formattedDF.show() @@ -430,22 +465,21 @@ def test_empty_range(self): assert empty_range.isEmpty() def test_reversed_ranges(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), minValue=100, maxValue=1, step=-1) - .withColumn("val2", IntegerType(), minValue=100, maxValue=1, step=-3, unique_values=5) - .withColumn("val3", IntegerType(), dataRange=NRange(100, 1, -1), unique_values=5) - .withColumn("val4", IntegerType(), minValue=1, maxValue=100, step=3, unique_values=5) - .withColumn("code1b", IntegerType(), minValue=1, maxValue=100) - .withColumn("code1c", IntegerType(), minValue=1, maxValue=200, unique_values=100) - .withColumn("code1d", IntegerType(), minValue=1, maxValue=200) - .withColumn("code2", IntegerType(), maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), minValue=100, maxValue=1, step=-1) + .withColumn("val2", IntegerType(), minValue=100, maxValue=1, step=-3, unique_values=5) + .withColumn("val3", IntegerType(), dataRange=NRange(100, 1, -1), unique_values=5) + .withColumn("val4", IntegerType(), minValue=1, maxValue=100, step=3, unique_values=5) + .withColumn("code1b", IntegerType(), minValue=1, maxValue=100) + .withColumn("code1c", IntegerType(), minValue=1, maxValue=200, unique_values=100) + .withColumn("code1d", IntegerType(), minValue=1, maxValue=200) + .withColumn("code2", IntegerType(), maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) rangedDF = testDataSpec.build() rangedDF.show() @@ -454,40 +488,40 @@ def test_reversed_ranges(self): assert rowCount == 100000 def test_date_time_ranges(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("last_sync_ts", "timestamp", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=1,hours=1")) - .withColumn("last_sync_ts", "timestamp", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=1,hours=1"), unique_values=5) - - .withColumn("last_sync_ts", "timestamp", - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d")) - - .withColumn("last_sync_dt1", DateType(), - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=1")) - .withColumn("last_sync_dt2", DateType(), - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=1"), unique_values=5) - - .withColumn("last_sync_date", DateType(), - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d")) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_ts", + "timestamp", 
+ dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=1,hours=1"), + ) + .withColumn( + "last_sync_ts", + "timestamp", + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=1,hours=1"), + unique_values=5, + ) + .withColumn( + "last_sync_ts", + "timestamp", + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + ) + .withColumn( + "last_sync_dt1", DateType(), dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=1") + ) + .withColumn( + "last_sync_dt2", + DateType(), + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=1"), + unique_values=5, + ) + .withColumn( + "last_sync_date", + DateType(), + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + ) + ) rangedDF = testDataSpec.build() rangedDF.show() @@ -499,25 +533,23 @@ def test_date_time_ranges(self): @pytest.mark.parametrize("asHtml", [True, False]) def test_script_table(self, asHtml): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str1", StringType(), format="test %d") - # .withColumn("str1a", StringType(), format="test %s") - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="values") - .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["one", "two", "three"]) - .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str1", StringType(), format="test %d") + # .withColumn("str1a", StringType(), format="test %s") + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="values") + .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn( + "str5b", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["one", "two", "three"] + ) + .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) + ) script = testDataSpec.scriptTable(name="Test", asHtml=asHtml) print(script) @@ -526,8 +558,19 @@ def test_script_table(self, asHtml): output_columns = testDataSpec.getOutputColumnNames() print(output_columns) - assert set(output_columns) == {'id', 'val1', 'val2', 'str1', 'str2', 'str3', 'str4', 'str5', - 'str5a', 'str5b', 'str6'} + assert set(output_columns) == { + 'id', + 'val1', + 'val2', + 'str1', + 'str2', + 'str3', + 'str4', + 
'str5', + 'str5a', + 'str5b', + 'str6', + } assert script is not None @@ -538,35 +581,49 @@ def test_script_table(self, asHtml): @pytest.mark.parametrize("asHtml", [True, False]) def test_script_merge1(self, asHtml): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str1", StringType(), format="test %d") - # .withColumn("str1a", StringType(), format="test %s") - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="values") - .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], - baseColumnType="hash") - .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("action", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["INS", "DEL", "UPDATE"]) - .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) - ) - - script = testDataSpec.scriptMerge(tgtName="Test", srcName="TestInc", joinExpr="src.id=tgt.id", - delExpr="src.action='DEL'", updateExpr="src.action='UPDATE", - asHtml=asHtml) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str1", StringType(), format="test %d") + # .withColumn("str1a", StringType(), format="test %s") + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="values") + .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], baseColumnType="hash") + .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn( + "action", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["INS", "DEL", "UPDATE"] + ) + .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) + ) + + script = testDataSpec.scriptMerge( + tgtName="Test", + srcName="TestInc", + joinExpr="src.id=tgt.id", + delExpr="src.action='DEL'", + updateExpr="src.action='UPDATE", + asHtml=asHtml, + ) print(script) output_columns = testDataSpec.getOutputColumnNames() print(output_columns) - assert set(output_columns) == {'id', 'val1', 'val2', - 'str1', 'str2', 'str3', 'str4', 'str5', 'str5a', 'action', 'str6'} + assert set(output_columns) == { + 'id', + 'val1', + 'val2', + 'str1', + 'str2', + 'str3', + 'str4', + 'str5', + 'str5a', + 'action', + 'str6', + } assert script is not None @@ -578,25 +635,23 @@ def test_script_merge1(self, asHtml): assert col in script def test_script_merge_min(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("val1", IntegerType(), unique_values=100) - .withColumn("val2", IntegerType(), minValue=1, maxValue=100) - .withColumn("str1", StringType(), format="test %d") - # .withColumn("str1a", StringType(), format="test 
%s") - .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="values") - .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], - base_column_type="hash") - .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) - .withColumn("action", StringType(), format="test %s", baseColumn=["val1", "val2"], - values=["INS", "DEL", "UPDATE"]) - .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="formattedDF", rows=100000, partitions=4) + .withIdOutput() + .withColumn("val1", IntegerType(), unique_values=100) + .withColumn("val2", IntegerType(), minValue=1, maxValue=100) + .withColumn("str1", StringType(), format="test %d") + # .withColumn("str1a", StringType(), format="test %s") + .withColumn("str2", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="values") + .withColumn("str3", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str4", StringType(), format="test %s", baseColumn=["val1", "val2"], base_column_type="hash") + .withColumn("str5", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn("str5a", StringType(), format="test %s", baseColumn=["val1", "val2"]) + .withColumn( + "action", StringType(), format="test %s", baseColumn=["val1", "val2"], values=["INS", "DEL", "UPDATE"] + ) + .withColumn("str6", StringType(), template=r"\v0 \v1", baseColumn=["val1", "val2"]) + ) script = testDataSpec.scriptMerge(tgtName="Test", srcName="TestInc", joinExpr="src.id=tgt.id") assert script is not None @@ -605,8 +660,19 @@ def test_script_merge_min(self): output_columns = testDataSpec.getOutputColumnNames() print(output_columns) - assert set(output_columns) == \ - {'id', 'val1', 'val2', 'str1', 'str2', 'str3', 'str4', 'str5', 'str5a', 'action', 'str6'} + assert set(output_columns) == { + 'id', + 'val1', + 'val2', + 'str1', + 'str2', + 'str3', + 'str4', + 'str5', + 'str5a', + 'action', + 'str6', + } assert script is not None @@ -618,13 +684,13 @@ def test_script_merge_min(self): assert col in script def test_strings_from_numeric_string_field1(self): - """ Check that order_id always generates a non null value when using random values""" - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, - partitions=4) - .withIdOutput() - .withColumn("order_num", minValue=1, maxValue=100000000, random=True) - .withColumn("order_id", prefix="order", baseColumn="order_num") - ) + """Check that order_id always generates a non null value when using random values""" + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, partitions=4) + .withIdOutput() + .withColumn("order_num", minValue=1, maxValue=100000000, random=True) + .withColumn("order_id", prefix="order", baseColumn="order_num") + ) testDataSpec.build(withTempView=True) @@ -638,14 +704,14 @@ def test_strings_from_numeric_string_field1(self): assert rowCount == 0 def test_strings_from_numeric_string_field2(self): - """ Check that order_id always generates a non null value when using non-random values""" - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, - 
partitions=4) - .withIdOutput() - # use step of -1 to ensure descending from max value - .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) - .withColumn("order_id", prefix="order", baseColumn="order_num") - ) + """Check that order_id always generates a non null value when using non-random values""" + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, partitions=4) + .withIdOutput() + # use step of -1 to ensure descending from max value + .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) + .withColumn("order_id", prefix="order", baseColumn="order_num") + ) testDataSpec.build(withTempView=True) @@ -659,14 +725,14 @@ def test_strings_from_numeric_string_field2(self): assert rowCount == 0 def test_strings_from_numeric_string_field2a(self): - """ Check that order_id always generates a non null value when using non-random values""" - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, - partitions=4) - .withIdOutput() - # use step of -1 to ensure descending from max value - .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) - .withColumn("order_id", "string", minValue=None, suffix="_order", baseColumn="order_num") - ) + """Check that order_id always generates a non null value when using non-random values""" + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, partitions=4) + .withIdOutput() + # use step of -1 to ensure descending from max value + .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) + .withColumn("order_id", "string", minValue=None, suffix="_order", baseColumn="order_num") + ) testDataSpec.build(withTempView=True) @@ -684,13 +750,13 @@ def test_strings_from_numeric_string_field2a(self): assert rowCount == 0 def test_strings_from_numeric_string_field3(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, - partitions=4) - .withIdOutput() - # default column type is string - .withColumn("order_num", minValue=1, maxValue=100000000, random=True) - .withColumn("order_id", prefix="order", baseColumn="order_num") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, partitions=4) + .withIdOutput() + # default column type is string + .withColumn("order_num", minValue=1, maxValue=100000000, random=True) + .withColumn("order_id", prefix="order", baseColumn="order_num") + ) testDataSpec.build(withTempView=True) @@ -704,13 +770,13 @@ def test_strings_from_numeric_string_field3(self): assert rowCount == 0 def test_strings_from_numeric_string_field4(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, - partitions=4) - .withIdOutput() - # default column type is string - .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) - .withColumn("order_id", prefix="order", baseColumn="order_num") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="stringsFromNumbers", rows=100000, partitions=4) + .withIdOutput() + # default column type is string + .withColumn("order_num", minValue=1, maxValue=100000000, step=-1) + .withColumn("order_id", prefix="order", baseColumn="order_num") + ) df = testDataSpec.build(withTempView=True) @@ -721,30 +787,34 @@ def test_strings_from_numeric_string_field4(self): rowCount = nullRowsDF.count() assert rowCount == 0 - @pytest.mark.parametrize("columnSpecOptions", [ - {"dataType": "byte", "minValue": 1, "maxValue": 
None}, - {"dataType": "byte", "minValue": None, "maxValue": 10}, - {"dataType": "short", "minValue": 1, "maxValue": None}, - {"dataType": "short", "minValue": None, "maxValue": 100}, - {"dataType": "integer", "minValue": 1, "maxValue": None}, - {"dataType": "integer", "minValue": None, "maxValue": 100}, - {"dataType": "long", "minValue": 1, "maxValue": None}, - {"dataType": "long", "minValue": None, "maxValue": 100}, - {"dataType": "float", "minValue": 1.0, "maxValue": None}, - {"dataType": "float", "minValue": None, "maxValue": 100.0}, - {"dataType": "double", "minValue": 1, "maxValue": None}, - {"dataType": "double", "minValue": None, "maxValue": 100.0} - ]) + @pytest.mark.parametrize( + "columnSpecOptions", + [ + {"dataType": "byte", "minValue": 1, "maxValue": None}, + {"dataType": "byte", "minValue": None, "maxValue": 10}, + {"dataType": "short", "minValue": 1, "maxValue": None}, + {"dataType": "short", "minValue": None, "maxValue": 100}, + {"dataType": "integer", "minValue": 1, "maxValue": None}, + {"dataType": "integer", "minValue": None, "maxValue": 100}, + {"dataType": "long", "minValue": 1, "maxValue": None}, + {"dataType": "long", "minValue": None, "maxValue": 100}, + {"dataType": "float", "minValue": 1.0, "maxValue": None}, + {"dataType": "float", "minValue": None, "maxValue": 100.0}, + {"dataType": "double", "minValue": 1, "maxValue": None}, + {"dataType": "double", "minValue": None, "maxValue": 100.0}, + ], + ) def test_random_generation_without_range_values(self, columnSpecOptions): dataType = columnSpecOptions.get("dataType", None) minValue = columnSpecOptions.get("minValue", None) maxValue = columnSpecOptions.get("maxValue", None) - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100, - partitions=4) - .withIdOutput() - # default column type is string - .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True) - ) + testDataSpec = ( + dg.DataGenerator( + sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100, partitions=4 + ).withIdOutput() + # default column type is string + .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True) + ) df = testDataSpec.build(withTempView=True) sortedDf = df.orderBy("randCol") diff --git a/tests/test_ranged_values_and_dates.py b/tests/test_ranged_values_and_dates.py index a1d66693..e6453895 100644 --- a/tests/test_ranged_values_and_dates.py +++ b/tests/test_ranged_values_and_dates.py @@ -20,9 +20,7 @@ def setUp(self): print("setting up") def test_date_range_object(self): - x = DateRange("2017-10-01 00:00:00", - "2018-10-06 11:55:00", - "days=7") + x = DateRange("2017-10-01 00:00:00", "2018-10-06 11:55:00", "days=7") print("date range", x) print("minValue", datetime.fromtimestamp(x.minValue)) print("maxValue", datetime.fromtimestamp(x.maxValue)) @@ -61,11 +59,12 @@ def test_basic_dates(self): start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) + .build() + ) self.assertIsNotNone(testDataDF.schema) 
self.assertIs(type(testDataDF.schema.fields[1].dataType), type(TimestampType())) @@ -88,11 +87,12 @@ def test_basic_dates_non_random(self): start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval) - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval) + .build() + ) self.assertIsNotNone(testDataDF.schema) self.assertIs(type(testDataDF.schema.fields[1].dataType), type(TimestampType())) @@ -111,14 +111,15 @@ def test_basic_dates_non_random(self): def test_basic_dates_minimal(self): '''test dates with just unique values''' - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=10000, partitions=4) - .withIdOutput() - .withColumn("last_sync_dt", "date", unique_values=100, random=True) - .withColumn("last_sync_dt2", "date", unique_values=100, base_column_type="values") - .withColumn("last_sync_dt3", "date", unique_values=300, base_column_type="values") - .withColumn("last_sync_dt4", "date", unique_values=300, random=True) - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=10000, partitions=4) + .withIdOutput() + .withColumn("last_sync_dt", "date", unique_values=100, random=True) + .withColumn("last_sync_dt2", "date", unique_values=100, base_column_type="values") + .withColumn("last_sync_dt3", "date", unique_values=300, base_column_type="values") + .withColumn("last_sync_dt4", "date", unique_values=300, random=True) + .build() + ) self.assertIsNotNone(testDataDF.schema) self.assertIs(type(testDataDF.schema.fields[1].dataType), type(DateType())) @@ -127,15 +128,16 @@ def test_basic_dates_minimal(self): self.assertIs(type(testDataDF.schema.fields[4].dataType), type(DateType())) # validation statements - df_min_and_max = testDataDF.agg(F.min("last_sync_dt").alias("min_dt1"), - F.max("last_sync_dt").alias("max_dt1"), - F.min("last_sync_dt2").alias("min_dt2"), - F.max("last_sync_dt2").alias("max_dt2"), - F.min("last_sync_dt3").alias("min_dt3"), - F.max("last_sync_dt3").alias("max_dt3"), - F.min("last_sync_dt4").alias("min_dt4"), - F.max("last_sync_dt4").alias("max_dt4"), - ) + df_min_and_max = testDataDF.agg( + F.min("last_sync_dt").alias("min_dt1"), + F.max("last_sync_dt").alias("max_dt1"), + F.min("last_sync_dt2").alias("min_dt2"), + F.max("last_sync_dt2").alias("max_dt2"), + F.min("last_sync_dt3").alias("min_dt3"), + F.max("last_sync_dt3").alias("max_dt3"), + F.min("last_sync_dt4").alias("min_dt4"), + F.max("last_sync_dt4").alias("max_dt4"), + ) min_and_max = df_min_and_max.collect()[0] self.assertGreaterEqual(min_and_max['min_dt1'], DateRange.DEFAULT_START_DATE) @@ -147,14 +149,15 @@ def test_basic_dates_minimal(self): self.assertLessEqual(min_and_max['max_dt3'], DateRange.DEFAULT_END_DATE) self.assertLessEqual(min_and_max['max_dt4'], DateRange.DEFAULT_END_DATE) - count_distinct = testDataDF.select(F.countDistinct("last_sync_dt"), - F.countDistinct("last_sync_dt2"), - F.countDistinct("last_sync_dt3"), - F.countDistinct("last_sync_dt4"), - ).collect()[0] - self.assertLessEqual( count_distinct[0], 100) - self.assertLessEqual( count_distinct[1], 100) - self.assertLessEqual( count_distinct[2], 300) + count_distinct = testDataDF.select( + 
F.countDistinct("last_sync_dt"), + F.countDistinct("last_sync_dt2"), + F.countDistinct("last_sync_dt3"), + F.countDistinct("last_sync_dt4"), + ).collect()[0] + self.assertLessEqual(count_distinct[0], 100) + self.assertLessEqual(count_distinct[1], 100) + self.assertLessEqual(count_distinct[2], 300) self.assertLessEqual(count_distinct[3], 300) def test_date_range1(self): @@ -163,14 +166,13 @@ def test_date_range1(self): start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) - .withColumn("last_sync_dt1", "timestamp", - dataRange=DateRange(start, end, interval), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) + .withColumn("last_sync_dt1", "timestamp", dataRange=DateRange(start, end, interval), random=True) + .build() + ) self.assertIsNotNone(testDataDF.schema) self.assertIs(type(testDataDF.schema.fields[2].dataType), type(TimestampType())) @@ -188,19 +190,21 @@ def test_date_range1(self): self.assertLessEqual(10, count_distinct) def test_date_range2(self): - #interval = timedelta(days=1, hours=1) + # interval = timedelta(days=1, hours=1) start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 6, 0, 0, 0) - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_dt1", "timestamp", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=1,hours=1"), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_dt1", + "timestamp", + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=1,hours=1"), + random=True, + ) + .build() + ) self.assertIsNotNone(testDataDF.schema) self.assertIs(type(testDataDF.schema.fields[1].dataType), type(TimestampType())) @@ -218,22 +222,25 @@ def test_date_range3(self): start = date(2017, 10, 1) end = date(2018, 10, 6) - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "date", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 11:55:00", - "days=7"), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "date", + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 11:55:00", "days=7"), + random=True, + ) + .build() + ) self.assertIsNotNone(testDataDF.schema) self.assertIs(type(testDataDF.schema.fields[1].dataType), type(DateType())) # validation statements - df_min_and_max = testDataDF.agg(F.min("last_sync_date").alias("min_dt"), - F.max("last_sync_date").alias("max_dt")) + df_min_and_max = testDataDF.agg( + F.min("last_sync_date").alias("min_dt"), F.max("last_sync_date").alias("max_dt") + ) min_and_max = df_min_and_max.collect()[0] min_dt = min_and_max['min_dt'] @@ -248,15 +255,14 @@ def test_date_range3(self): self.assertEqual(df_outside2.count(), 0) def test_date_range3a(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, 
name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "date", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=7")) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", "date", dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=7") + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -272,16 +278,17 @@ def test_date_range3a(self): self.assertEqual(df_outside2.count(), 0) def test_date_range4(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "date", - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d"), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "date", + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + random=True, + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -297,16 +304,16 @@ def test_date_range4(self): self.assertEqual(df_outside2.count(), 0) def test_date_range4a(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "date", - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d")) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "date", + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -323,15 +330,17 @@ def test_date_range4a(self): # @unittest.skip("not yet finalized") def test_timestamp_range3(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "timestamp", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=7"), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "timestamp", + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=7"), + random=True, + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -347,15 +356,16 @@ def test_timestamp_range3(self): self.assertEqual(df_outside2.count(), 0) def test_timestamp_range3a(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "timestamp", - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "days=7")) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "timestamp", + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "days=7"), + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -371,16 +381,17 @@ def test_timestamp_range3a(self): 
self.assertEqual(df_outside2.count(), 0) def test_timestamp_range4(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "timestamp", - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d"), random=True) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "timestamp", + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + random=True, + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -396,16 +407,16 @@ def test_timestamp_range4(self): self.assertEqual(df_outside2.count(), 0) def test_timestamp_range4a(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("last_sync_date", "timestamp", - dataRange=DateRange("2017-10-01", - "2018-10-06", - "days=7", - datetime_format="%Y-%m-%d")) - - .build() - ) + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn( + "last_sync_date", + "timestamp", + dataRange=DateRange("2017-10-01", "2018-10-06", "days=7", datetime_format="%Y-%m-%d"), + ) + .build() + ) print("schema", testDataDF.schema) testDataDF.printSchema() @@ -421,17 +432,17 @@ def test_timestamp_range4a(self): self.assertEqual(df_outside2.count(), 0) def test_unique_values1(self): - testDataDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) - .withIdOutput() - .withColumn("code1", "int", unique_values=7) - .withColumn("code2", "int", unique_values=7, minValue=20) - .build() - ) - - testDataSummary = testDataDF.selectExpr("min(code1) as min_c1", - "max(code1) as max_c1", - "min(code2) as min_c2", - "max(code2) as max_c2") + testDataDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, partitions=4) + .withIdOutput() + .withColumn("code1", "int", unique_values=7) + .withColumn("code2", "int", unique_values=7, minValue=20) + .build() + ) + + testDataSummary = testDataDF.selectExpr( + "min(code1) as min_c1", "max(code1) as max_c1", "min(code2) as min_c2", "max(code2) as max_c2" + ) summary = testDataSummary.collect()[0] self.assertEqual(summary[0], 1) @@ -440,11 +451,12 @@ def test_unique_values1(self): self.assertEqual(summary[3], 26) def test_unique_values_ts(self): - testDataUniqueDF = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "timestamp", unique_values=51, random=True) - .build() - ) + testDataUniqueDF = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) + .withIdOutput() + .withColumn("test_ts", "timestamp", unique_values=51, random=True) + .build() + ) testDataUniqueDF.createOrReplaceTempView("testUnique1") @@ -453,11 +465,12 @@ def test_unique_values_ts(self): self.assertEqual(summary[0], 51) def test_unique_values_ts2(self): - df_unique_ts2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "timestamp", unique_values=51) - .build() - ) + df_unique_ts2 = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) + .withIdOutput() + .withColumn("test_ts", "timestamp", unique_values=51) + .build() + ) 
df_unique_ts2.createOrReplaceTempView("testUnique2") @@ -468,12 +481,15 @@ def test_unique_values_ts2(self): def test_unique_values_ts3(self): testDataUniqueTSDF = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "timestamp", unique_values=51, random=True, - dataRange=DateRange("2017-10-01 00:00:00", - "2018-10-06 00:00:00", - "minutes=10")) - .build() + .withIdOutput() + .withColumn( + "test_ts", + "timestamp", + unique_values=51, + random=True, + dataRange=DateRange("2017-10-01 00:00:00", "2018-10-06 00:00:00", "minutes=10"), + ) + .build() ) testDataUniqueTSDF.createOrReplaceTempView("testUniqueTS3") @@ -486,10 +502,17 @@ def test_unique_values_ts4(self): df_unique_ts4 = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "timestamp", unique_values=51, random=True, - begin="2017-10-01 00:00:00", end="2018-10-06 23:59:59", interval="minutes=10") - .build() + .withIdOutput() + .withColumn( + "test_ts", + "timestamp", + unique_values=51, + random=True, + begin="2017-10-01 00:00:00", + end="2018-10-06 23:59:59", + interval="minutes=10", + ) + .build() ) df_unique_ts4.createOrReplaceTempView("testUniqueTS4") @@ -501,8 +524,8 @@ def test_unique_values_ts4(self): def test_unique_values_date(self): testDataUniqueDF3spec = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "date", unique_values=51, interval="1 days") + .withIdOutput() + .withColumn("test_ts", "date", unique_values=51, interval="1 days") ) testDataUniqueDF3 = testDataUniqueDF3spec.build() @@ -515,12 +538,13 @@ def test_unique_values_date(self): self.assertEqual(summary[0], 51) def test_unique_values_date2(self): - ''' Check for unique dates''' - df_unique_date2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "date", unique_values=51, random=True) - .build() - ) + '''Check for unique dates''' + df_unique_date2 = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) + .withIdOutput() + .withColumn("test_ts", "date", unique_values=51, random=True) + .build() + ) df_unique_date2.createOrReplaceTempView("testUnique4") @@ -529,13 +553,20 @@ def test_unique_values_date2(self): self.assertEqual(summary[0], 51) def test_unique_values_date3(self): - ''' Check for unique dates when begin, end and interval are specified''' + '''Check for unique dates when begin, end and interval are specified''' df_unique_date3 = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("test_ts", "date", unique_values=51, random=True, begin="2017-10-01", end="2018-10-06", - interval="days=2") - .build() + .withIdOutput() + .withColumn( + "test_ts", + "date", + unique_values=51, + random=True, + begin="2017-10-01", + end="2018-10-06", + interval="days=2", + ) + .build() ) df_unique_date3.createOrReplaceTempView("testUnique4a") @@ -545,13 +576,20 @@ def test_unique_values_date3(self): self.assertEqual(summary[0], 51) def test_unique_values_date3a(self): - ''' Check for unique dates when begin, end and interval are specified''' + '''Check for unique dates when begin, end and interval are specified''' df_unique_date3 = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - 
.withColumn("test_ts", "date", unique_values=51, random=True, begin="2017-10-01", end="2018-10-06", - interval="days=1") - .build() + .withIdOutput() + .withColumn( + "test_ts", + "date", + unique_values=51, + random=True, + begin="2017-10-01", + end="2018-10-06", + interval="days=1", + ) + .build() ) df_unique_date3.createOrReplaceTempView("testUnique4a") @@ -563,21 +601,22 @@ def test_unique_values_date3a(self): def test_unique_values_integers(self): testDataUniqueIntegersDF = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("val1", "int", unique_values=51, random=True) - .withColumn("val2", "int", unique_values=57) - .withColumn("val3", "long", unique_values=93) - .withColumn("val4", "long", unique_values=87, random=True) - .withColumn("val5", "short", unique_values=93) - .withColumn("val6", "short", unique_values=87, random=True) - .withColumn("val7", "byte", unique_values=93) - .withColumn("val8", "byte", unique_values=87, random=True) - .build() + .withIdOutput() + .withColumn("val1", "int", unique_values=51, random=True) + .withColumn("val2", "int", unique_values=57) + .withColumn("val3", "long", unique_values=93) + .withColumn("val4", "long", unique_values=87, random=True) + .withColumn("val5", "short", unique_values=93) + .withColumn("val6", "short", unique_values=87, random=True) + .withColumn("val7", "byte", unique_values=93) + .withColumn("val8", "byte", unique_values=87, random=True) + .build() ) testDataUniqueIntegersDF.createOrReplaceTempView("testUniqueIntegers") - dfResults = spark.sql(""" + dfResults = spark.sql( + """ select count(distinct val1), count(distinct val2), count(distinct val3), count(distinct val4), count(distinct val5), @@ -585,7 +624,9 @@ def test_unique_values_integers(self): count(distinct val7), count(distinct val8) from testUniqueIntegers - """"") + """ + "" + ) summary = dfResults.collect()[0] self.assertEqual(summary[0], 51) self.assertEqual(summary[1], 57) @@ -600,23 +641,26 @@ def test_unique_values_integers(self): def test_unique_values_decimal(self): testDataUniqueDecimalsDF = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("val1", "decimal(15,5)", unique_values=51, random=True) - .withColumn("val2", "decimal(15,5)", unique_values=57) - .withColumn("val3", "decimal(10,4)", unique_values=93) - .withColumn("val4", "decimal(10,0)", unique_values=87, random=True) - .build() + .withIdOutput() + .withColumn("val1", "decimal(15,5)", unique_values=51, random=True) + .withColumn("val2", "decimal(15,5)", unique_values=57) + .withColumn("val3", "decimal(10,4)", unique_values=93) + .withColumn("val4", "decimal(10,0)", unique_values=87, random=True) + .build() ) testDataUniqueDecimalsDF.createOrReplaceTempView("testUniqueDecimal") - dfResults = spark.sql(""" + dfResults = spark.sql( + """ select count(distinct val1), count(distinct val2), count(distinct val3), count(distinct val4) from testUniqueDecimal - """"") + """ + "" + ) summary = dfResults.collect()[0] self.assertEqual(summary[0], 51) self.assertEqual(summary[1], 57) @@ -627,23 +671,26 @@ def test_unique_values_decimal(self): def test_unique_values_float(self): testDataUniqueFloatssDF = ( dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4) - .withIdOutput() - .withColumn("val1", "float", unique_values=51, random=True) - .withColumn("val2", "float", unique_values=57) - .withColumn("val3", "double", unique_values=93) - 
.withColumn("val4", "double", unique_values=87, random=True) - .build() + .withIdOutput() + .withColumn("val1", "float", unique_values=51, random=True) + .withColumn("val2", "float", unique_values=57) + .withColumn("val3", "double", unique_values=93) + .withColumn("val4", "double", unique_values=87, random=True) + .build() ) testDataUniqueFloatssDF.createOrReplaceTempView("testUniqueFloats") - dfResults = spark.sql(""" + dfResults = spark.sql( + """ select count(distinct val1), count(distinct val2), count(distinct val3), count(distinct val4) from testUniqueFloats - """"") + """ + "" + ) summary = dfResults.collect()[0] self.assertEqual(summary[0], 51) self.assertEqual(summary[1], 57) @@ -653,27 +700,31 @@ def test_unique_values_float(self): def test_unique_values_float2(self): df_unique_float2 = ( - dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=4, verbose=True, - debug=True) - .withIdOutput() - .withColumn("val1", "float", unique_values=51, random=True, minValue=1.0) - .withColumn("val2", "float", unique_values=57, minValue=-5.0) - .withColumn("val3", "double", unique_values=93, minValue=1.0, step=0.24) - .withColumn("val4", "double", unique_values=87, random=True, minValue=1.0, step=0.24) - .build() + dg.DataGenerator( + sparkSession=spark, name="test_data_set1", rows=100000, partitions=4, verbose=True, debug=True + ) + .withIdOutput() + .withColumn("val1", "float", unique_values=51, random=True, minValue=1.0) + .withColumn("val2", "float", unique_values=57, minValue=-5.0) + .withColumn("val3", "double", unique_values=93, minValue=1.0, step=0.24) + .withColumn("val4", "double", unique_values=87, random=True, minValue=1.0, step=0.24) + .build() ) df_unique_float2.show() df_unique_float2.createOrReplaceTempView("testUniqueFloats2") - dfResults = spark.sql(""" + dfResults = spark.sql( + """ select count(distinct val1), count(distinct val2), count(distinct val3), count(distinct val4) from testUniqueFloats2 - """"") + """ + "" + ) summary = dfResults.collect()[0] self.assertEqual(summary[0], 51) self.assertEqual(summary[1], 57) @@ -682,17 +733,16 @@ def test_unique_values_float2(self): print("passed") def test_ranged_data_int(self): - ds_data_int = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("nint", IntegerType(), minValue=1, maxValue=9, step=2) - .withColumn("nint2", IntegerType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumn("nint3", IntegerType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, - random=True) - .withColumn("sint", ShortType(), minValue=1, maxValue=9, step=2) - .withColumn("sint2", ShortType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumn("sint3", ShortType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, - random=True) - ) + ds_data_int = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("nint", IntegerType(), minValue=1, maxValue=9, step=2) + .withColumn("nint2", IntegerType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumn("nint3", IntegerType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, random=True) + .withColumn("sint", ShortType(), minValue=1, maxValue=9, step=2) + .withColumn("sint2", ShortType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumn("sint3", ShortType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, random=True) + ) results = ds_data_int.build() @@ -719,15 +769,15 @@ def test_ranged_data_int(self): def 
test_ranged_data_long(self): # note python 3.6 does not support trailing long literal syntax (i.e 200L) - but all literal ints are long long_min = 3147483651 - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("lint", LongType(), minValue=long_min, maxValue=long_min + 8, step=2) - .withColumn("lint2", LongType(), minValue=long_min, maxValue=long_min + 8, step=2, - percent_nulls=0.1) - .withColumn("lint3", LongType(), minValue=long_min, maxValue=long_min + 8, step=2, - percent_nulls=0.1, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("lint", LongType(), minValue=long_min, maxValue=long_min + 8, step=2) + .withColumn("lint2", LongType(), minValue=long_min, maxValue=long_min + 8, step=2, percent_nulls=0.1) + .withColumn( + "lint3", LongType(), minValue=long_min, maxValue=long_min + 8, step=2, percent_nulls=0.1, random=True + ) + ) results = testDataSpec.build() @@ -742,15 +792,14 @@ def test_ranged_data_long(self): self.assertSetEqual(set(nint3_values), {None, long_min, long_min + 2, long_min + 4, long_min + 6, long_min + 8}) def test_ranged_data_byte(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("byte1", ByteType(), minValue=1, maxValue=9, step=2) - .withColumn("byte2", ByteType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumn("byte3", ByteType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, - random=True) - .withColumn("byte4", ByteType(), percent_nulls=0.1, minValue=-5, maxValue=5, step=2, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("byte1", ByteType(), minValue=1, maxValue=9, step=2) + .withColumn("byte2", ByteType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumn("byte3", ByteType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2, random=True) + .withColumn("byte4", ByteType(), percent_nulls=0.1, minValue=-5, maxValue=5, step=2, random=True) + ) results = testDataSpec.build() @@ -768,17 +817,16 @@ def test_ranged_data_byte(self): self.assertSetEqual(set(byte4_values), {None, -5, -3, -1, 1, 3, 5}) def test_ranged_data_float1(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("fval", FloatType(), minValue=1.0, maxValue=9.0, step=2.0) - .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0) - .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0, - random=True) - .withColumn("dval1", DoubleType(), minValue=1.0, maxValue=9.0, step=2.0) - .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0) - .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("fval", FloatType(), minValue=1.0, maxValue=9.0, step=2.0) + .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0) + .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0, random=True) + .withColumn("dval1", DoubleType(), minValue=1.0, maxValue=9.0, step=2.0) + .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.0, 
maxValue=9.0, step=2.0) + .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0, random=True) + ) results = testDataSpec.build() @@ -804,17 +852,16 @@ def test_ranged_data_float1(self): self.assertSetEqual(set(double3_values), {None, 1, 3, 5, 7, 9}) def test_ranged_data_float2(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("fval", FloatType(), minValue=1.5, maxValue=3.5, step=0.5) - .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5) - .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5, - random=True) - .withColumn("dval1", DoubleType(), minValue=1.5, maxValue=3.5, step=0.5) - .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5) - .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("fval", FloatType(), minValue=1.5, maxValue=3.5, step=0.5) + .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5) + .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5, random=True) + .withColumn("dval1", DoubleType(), minValue=1.5, maxValue=3.5, step=0.5) + .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5) + .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=3.5, step=0.5, random=True) + ) results = testDataSpec.build() @@ -846,17 +893,16 @@ def roundIfNotNull(x, scale): def test_ranged_data_float3(self): # when modulo arithmetic does not result in even integer such as ' - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, verbose=True) - .withIdOutput() - .withColumn("fval", FloatType(), minValue=1.5, maxValue=2.5, step=0.3) - .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3) - .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - random=True) - .withColumn("dval1", DoubleType(), minValue=1.5, maxValue=2.5, step=0.3) - .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3) - .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, verbose=True) + .withIdOutput() + .withColumn("fval", FloatType(), minValue=1.5, maxValue=2.5, step=0.3) + .withColumn("fval2", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3) + .withColumn("fval3", FloatType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumn("dval1", DoubleType(), minValue=1.5, maxValue=2.5, step=0.3) + .withColumn("dval2", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3) + .withColumn("dval3", DoubleType(), percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + ) results = testDataSpec.build() @@ -883,18 +929,18 @@ def test_ranged_data_float3(self): self.assertSetEqual(set(double3_values), {None, 1.5, 1.8, 2.1, 2.4}) def test_ranged_data_decimal1(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("decimal1", DecimalType(10, 4), minValue=1.0, 
maxValue=9.0, step=2.0) - .withColumn("decimal2", DecimalType(10, 4), percent_nulls=0.1, minValue=1.0, maxValue=9.0, - step=2.0) - .withColumn("decimal3", DecimalType(10, 4), percent_nulls=0.1, minValue=1.0, maxValue=9.0, - step=2.0, - random=True) - .withColumn("decimal4", DecimalType(10, 4), percent_nulls=0.1, minValue=-5, maxValue=5, - step=2.0, - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("decimal1", DecimalType(10, 4), minValue=1.0, maxValue=9.0, step=2.0) + .withColumn("decimal2", DecimalType(10, 4), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0) + .withColumn( + "decimal3", DecimalType(10, 4), percent_nulls=0.1, minValue=1.0, maxValue=9.0, step=2.0, random=True + ) + .withColumn( + "decimal4", DecimalType(10, 4), percent_nulls=0.1, minValue=-5, maxValue=5, step=2.0, random=True + ) + ) results = testDataSpec.build() @@ -913,10 +959,11 @@ def test_ranged_data_decimal1(self): self.assertSetEqual(set(decimal4_values), {None, -5, -3, -1, 1, 3, 5}) def test_ranged_data_string1(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("s1", StringType(), minValue=1, maxValue=123, step=1, format="testing %05d >>") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("s1", StringType(), minValue=1, maxValue=123, step=1, format="testing %05d >>") + ) results = testDataSpec.build() @@ -926,10 +973,11 @@ def test_ranged_data_string1(self): self.assertSetEqual(set(s1_expected_values), set(s1_values)) def test_ranged_data_string2(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("s1", StringType(), minValue=10, maxValue=123, step=1, format="testing %05d >>") - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("s1", StringType(), minValue=10, maxValue=123, step=1, format="testing %05d >>") + ) results = testDataSpec.build() @@ -939,25 +987,25 @@ def test_ranged_data_string2(self): self.assertSetEqual(set(s1_expected_values), set(s1_values)) def test_ranged_data_string3(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("s1", StringType(), minValue=10, maxValue=123, step=1, - format="testing %05d >>", random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("s1", StringType(), minValue=10, maxValue=123, step=1, format="testing %05d >>", random=True) + ) results = testDataSpec.build() # check `s1` values s1_expected_values = [f"testing {x:05} >>" for x in range(10, 124)] s1_values = [r[0] for r in results.select("s1").distinct().collect()] - self.assertTrue( set(s1_values).issubset(set(s1_expected_values))) + self.assertTrue(set(s1_values).issubset(set(s1_expected_values))) def test_ranged_data_string4(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("s1", StringType(), minValue=10, maxValue=123, step=2, - format="testing %05d >>", random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("s1", StringType(), minValue=10, maxValue=123, step=2, format="testing %05d >>", random=True) + 
) results = testDataSpec.build() @@ -969,12 +1017,13 @@ def test_ranged_data_string4(self): self.assertSetEqual(set(s1_expected_values), set(s1_values)) def test_ranged_data_string5(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("s1", StringType(), minValue=1.5, maxValue=2.5, step=0.3, - format="testing %05.1f >>", - random=True) - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn( + "s1", StringType(), minValue=1.5, maxValue=2.5, step=0.3, format="testing %05.1f >>", random=True + ) + ) results = testDataSpec.build() diff --git a/tests/test_repeatable_data.py b/tests/test_repeatable_data.py index e3d643c3..7283ad31 100644 --- a/tests/test_repeatable_data.py +++ b/tests/test_repeatable_data.py @@ -24,21 +24,23 @@ def mkBasicDataspec(cls, withRandom=False, dist=None, randomSeed=None): if randomSeed is None: dgSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=cls.row_count) else: - dgSpec = dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=cls.row_count, - randomSeed=randomSeed, randomSeedMethod='hash_fieldname') - - testDataSpec = (dgSpec - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand(42) * 350) * (86400 + 3600)", - numColumns=cls.column_count) - .withColumn("code1", IntegerType(), minValue=100, maxValue=200, random=withRandom) - .withColumn("code2", IntegerType(), minValue=0, maxValue=1000000, - random=withRandom, distribution=dist) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=withRandom) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], - random=withRandom, weights=[9, 1, 1]) - ) + dgSpec = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=cls.row_count, + randomSeed=randomSeed, + randomSeedMethod='hash_fieldname', + ) + + testDataSpec = ( + dgSpec.withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand(42) * 350) * (86400 + 3600)", numColumns=cls.column_count) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200, random=withRandom) + .withColumn("code2", IntegerType(), minValue=0, maxValue=1000000, random=withRandom, distribution=dist) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=withRandom) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=withRandom, weights=[9, 1, 1]) + ) return testDataSpec @@ -150,9 +152,9 @@ def test_basic_clone_with_random(self): def test_clone_with_new_column(self): """Test clone method""" - ds1 = (self.mkBasicDataspec(withRandom=True, dist="normal") - .withColumn("another_column", StringType(), values=['a', 'b', 'c'], random=True) - ) + ds1 = self.mkBasicDataspec(withRandom=True, dist="normal").withColumn( + "another_column", StringType(), values=['a', 'b', 'c'], random=True + ) df1 = ds1.build() ds2 = ds1.clone() @@ -162,11 +164,11 @@ def test_clone_with_new_column(self): def test_multiple_base_columns(self): """Test data generation with multiple base columns""" - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("ac1", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200) - .withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], - minValue=100, maxValue=200, random=True) - ) + ds1 = ( + self.mkBasicDataspec(withRandom=True) + .withColumn("ac1", IntegerType(), 
baseColumn=['code1', 'code2'], minValue=100, maxValue=200) + .withColumn("ac2", IntegerType(), baseColumn=['code1', 'code2'], minValue=100, maxValue=200, random=True) + ) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -175,9 +177,7 @@ def test_multiple_base_columns(self): def test_date_column(self): """Test data generation with date columns""" - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("dt1", DateType(), random=True) - ) + ds1 = self.mkBasicDataspec(withRandom=True).withColumn("dt1", DateType(), random=True) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -186,9 +186,7 @@ def test_date_column(self): def test_timestamp_column(self): """Test data generation with timestamp columns""" - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("ts1", TimestampType(), random=True) - ) + ds1 = self.mkBasicDataspec(withRandom=True).withColumn("ts1", TimestampType(), random=True) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -197,9 +195,7 @@ def test_timestamp_column(self): def test_template_column(self): """Test data generation with _template columns""" - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("txt1", "string", template=r"dr_\\v") - ) + ds1 = self.mkBasicDataspec(withRandom=True).withColumn("txt1", "string", template=r"dr_\\v") df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -211,9 +207,7 @@ def test_template_column(self): def test_template_column_random(self): """Test data generation with _template columns""" - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("txt1", "string", template=r"\dr_\v", random=True) - ) + ds1 = self.mkBasicDataspec(withRandom=True).withColumn("txt1", "string", template=r"\dr_\v", random=True) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -224,13 +218,12 @@ def test_template_column_random(self): self.checkTablesEqual(df1, df2) def test_template_column_random2(self): - """Test data generation with _template columns - - """ - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("txt1", "string", template=r"dr_\v", random=True, escapeSpecialChars=True) - .withColumn("nonRandom", "string", baseColumn="code1") - ) + """Test data generation with _template columns""" + ds1 = ( + self.mkBasicDataspec(withRandom=True) + .withColumn("txt1", "string", template=r"dr_\v", random=True, escapeSpecialChars=True) + .withColumn("nonRandom", "string", baseColumn="code1") + ) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -254,12 +247,10 @@ def test_template_column_random2(self): self.assertEqual(value0, "dr_0") def test_ILText_column_random2(self): - """Test data generation with _template columns - - """ - ds1 = (self.mkBasicDataspec(withRandom=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - ) + """Test data generation with _template columns""" + ds1 = self.mkBasicDataspec(withRandom=True).withColumn( + "paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)) + ) df1 = ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -267,12 +258,10 @@ def test_ILText_column_random2(self): self.checkTablesEqual(df1, df2) def test_ILText_column_random3(self): - """Test data generation with _template columns - - """ - ds1 = (self.mkBasicDataspec(withRandom=True, randomSeed=41) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - ) + """Test data generation with _template columns""" + ds1 = self.mkBasicDataspec(withRandom=True, randomSeed=41).withColumn( + "paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)) + ) df1 = 
ds1.build() ds2 = ds1.clone() df2 = ds2.build() @@ -282,14 +271,15 @@ def test_random_seed_flow(self): default_random_seed = dg.DEFAULT_RANDOM_SEED - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, randomSeed=2021) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator(spark, rows=data_rows) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, randomSeed=2021) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, dg.DEFAULT_RANDOM_SEED) @@ -320,14 +310,15 @@ def test_random_seed_flow2(self): effective_random_seed = dg.DEFAULT_RANDOM_SEED - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator(spark, rows=data_rows) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed) @@ -347,19 +338,20 @@ def test_random_seed_flow2(self): self.assertEqual(text2Spec.randomSeed, text2Spec.textGenerator.randomSeed) def test_random_seed_flow_explicit_instance(self): - """ Check the explicit random seed is applied to all columns""" + """Check the explicit random seed is applied to all columns""" data_rows = 100 * 1000 effective_random_seed = 1017 - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), 
sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed, "dataspec") @@ -387,21 +379,22 @@ def test_random_seed_flow_explicit_instance(self): self.assertEqual(paras2Spec.randomSeed, paras2Spec.textGenerator.randomSeed, "paras2Spec with textGenerator") def test_random_seed_flow_hash_fieldname(self): - """ Check the explicit random seed is applied to all columns""" + """Check the explicit random seed is applied to all columns""" data_rows = 100 * 1000 effective_random_seed = 1017 - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, - randomSeed=effective_random_seed, - randomSeedMethod=dg.RANDOM_SEED_HASH_FIELD_NAME) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator( + spark, rows=data_rows, randomSeed=effective_random_seed, randomSeedMethod=dg.RANDOM_SEED_HASH_FIELD_NAME + ) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed, "dataspec") @@ -425,21 +418,23 @@ def test_random_seed_flow_hash_fieldname(self): self.assertEqual(paras2Spec.randomSeed, paras2Spec.textGenerator.randomSeed, "paras2Spec with textGenerator") def test_random_seed_flow3_true_random(self): - """ Check the explicit random seed (-1) is applied to all columns""" + """Check the explicit random seed (-1) is applied to all columns""" data_rows = 100 * 1000 effective_random_seed = -1 explicitRandomSeed = 41 - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True, randomSeed=explicitRandomSeed) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True, - randomSeedMethod="fixed") - ) + pluginDataspec = ( + dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True, randomSeed=explicitRandomSeed) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn( + "paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True, randomSeedMethod="fixed" + ) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed, "dataspec") @@ -471,16 +466,17 @@ def test_random_seed_flow3a(self): effective_random_seed = 1017 - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, - randomSeed=effective_random_seed, - 
randomSeedMethod=dg.RANDOM_SEED_FIXED) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator( + spark, rows=data_rows, randomSeed=effective_random_seed, randomSeedMethod=dg.RANDOM_SEED_FIXED + ) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed, "dataspec") @@ -503,14 +499,15 @@ def test_seed_flow4(self): effective_random_seed = dg.RANDOM_SEED_RANDOM - pluginDataspec = (dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) - .withColumn("code1", minValue=0, maxValue=100) - .withColumn("code2", minValue=0, maxValue=100, random=True) - .withColumn("text", "string", template=r"dr_\\v") - .withColumn("text2", "string", template=r"dr_\\v", random=True) - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) - .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) - ) + pluginDataspec = ( + dg.DataGenerator(spark, rows=data_rows, randomSeed=effective_random_seed) + .withColumn("code1", minValue=0, maxValue=100) + .withColumn("code2", minValue=0, maxValue=100, random=True) + .withColumn("text", "string", template=r"dr_\\v") + .withColumn("text2", "string", template=r"dr_\\v", random=True) + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6))) + .withColumn("paras2", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6)), random=True) + ) self.assertEqual(pluginDataspec.randomSeed, effective_random_seed, "dataspec") diff --git a/tests/test_schema_parser.py b/tests/test_schema_parser.py index b1270099..b925c405 100644 --- a/tests/test_schema_parser.py +++ b/tests/test_schema_parser.py @@ -1,8 +1,24 @@ import logging import pytest -from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \ - TimestampType, DateType, DecimalType, ByteType, BinaryType, ArrayType, MapType, StructType, StructField +from pyspark.sql.types import ( + LongType, + FloatType, + IntegerType, + StringType, + DoubleType, + BooleanType, + ShortType, + TimestampType, + DateType, + DecimalType, + ByteType, + BinaryType, + ArrayType, + MapType, + StructType, + StructField, +) import dbldatagen as dg @@ -17,67 +33,89 @@ def setupLogging(): class TestSchemaParser: - @pytest.mark.parametrize("typeDefn, expectedTypeDefn", - [("byte", ByteType()), - ("tinyint", ByteType()), - ("short", ShortType()), - ("smallint", ShortType()), - ("int", IntegerType()), - ("integer", IntegerType()), - ("long", LongType()), - ("LONG", LongType()), - ("bigint", LongType()), - ("date", DateType()), - ("binary", BinaryType()), - ("timestamp", TimestampType()), - ("bool", BooleanType()), - ("boolean", BooleanType()), - ("string", StringType()), - ("char(10)", StringType()), - ("nvarchar(14)", StringType()), - ("nvarchar", 
StringType()), - ("varchar", StringType()), - ("varchar(10)", StringType()) - ]) + @pytest.mark.parametrize( + "typeDefn, expectedTypeDefn", + [ + ("byte", ByteType()), + ("tinyint", ByteType()), + ("short", ShortType()), + ("smallint", ShortType()), + ("int", IntegerType()), + ("integer", IntegerType()), + ("long", LongType()), + ("LONG", LongType()), + ("bigint", LongType()), + ("date", DateType()), + ("binary", BinaryType()), + ("timestamp", TimestampType()), + ("bool", BooleanType()), + ("boolean", BooleanType()), + ("string", StringType()), + ("char(10)", StringType()), + ("nvarchar(14)", StringType()), + ("nvarchar", StringType()), + ("varchar", StringType()), + ("varchar(10)", StringType()), + ], + ) def test_primitive_type_parser(self, typeDefn, expectedTypeDefn, setupLogging): output_type = dg.SchemaParser.columnTypeFromString(typeDefn) assert output_type == expectedTypeDefn, f"Expect output type {output_type} to match {expectedTypeDefn}" - @pytest.mark.parametrize("typeDefn, expectedTypeDefn", - [("float", FloatType()), - ("real", FloatType()), - ("double", DoubleType()), - ("decimal", DecimalType(10, 0)), - ("decimal(11)", DecimalType(11, 0)), - ("decimal(15,3)", DecimalType(15, 3)), - ]) + @pytest.mark.parametrize( + "typeDefn, expectedTypeDefn", + [ + ("float", FloatType()), + ("real", FloatType()), + ("double", DoubleType()), + ("decimal", DecimalType(10, 0)), + ("decimal(11)", DecimalType(11, 0)), + ("decimal(15,3)", DecimalType(15, 3)), + ], + ) def test_numeric_type_parser(self, typeDefn, expectedTypeDefn, setupLogging): output_type = dg.SchemaParser.columnTypeFromString(typeDefn) assert output_type == expectedTypeDefn, f"Expect output type {output_type} to match {expectedTypeDefn}" - @pytest.mark.parametrize("typeDefn, expectedTypeDefn", - [("array", ArrayType(IntegerType())), - ("array>", ArrayType(ArrayType(StringType()))), - ("map", MapType(StringType(), IntegerType())), - ("struct", - StructType([StructField("a", BinaryType()), StructField("b", IntegerType()), - StructField("c", FloatType())])), - ("struct", - StructType([StructField('event_type', StringType()), - StructField('event_ts', TimestampType())])) - ]) + @pytest.mark.parametrize( + "typeDefn, expectedTypeDefn", + [ + ("array", ArrayType(IntegerType())), + ("array>", ArrayType(ArrayType(StringType()))), + ("map", MapType(StringType(), IntegerType())), + ( + "struct", + StructType( + [StructField("a", BinaryType()), StructField("b", IntegerType()), StructField("c", FloatType())] + ), + ), + ( + "struct", + StructType([StructField('event_type', StringType()), StructField('event_ts', TimestampType())]), + ), + ], + ) def test_complex_type_parser(self, typeDefn, expectedTypeDefn, setupLogging): output_type = dg.SchemaParser.columnTypeFromString(typeDefn) assert output_type == expectedTypeDefn, f"Expect output type {output_type} to match {expectedTypeDefn}" - @pytest.mark.parametrize("typeDefn", - ["decimal(15,3, 3)", "array", "decimal()", - "interval", "array", "map", - "struct", "binary_float" - ]) + @pytest.mark.parametrize( + "typeDefn", + [ + "decimal(15,3, 3)", + "array", + "decimal()", + "interval", + "array", + "map", + "struct", + "binary_float", + ], + ) def test_parser_exceptions(self, typeDefn, setupLogging): with pytest.raises(Exception) as e_info: output_type = dg.SchemaParser.columnTypeFromString(typeDefn) @@ -121,20 +159,30 @@ def test_table_definition_parser(self, setupLogging): assert "name" in schema3.fieldNames() assert "age" in schema3.fieldNames() - @pytest.mark.parametrize("sqlExpr, 
expectedText", - [("named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - "named_struct(' ', city_name, ' ', city_id, ' ', city_pop)"), - ("named_struct('name', `city 2`, 'id', city_id, 'population', city_pop)", - "named_struct(' ', `city 2`, ' ', city_id, ' ', city_pop)"), - ("named_struct('`name 1`', `city 2`, 'id', city_id, 'population', city_pop)", - "named_struct(' ', `city 2`, ' ', city_id, ' ', city_pop)"), - ("named_struct('`name 1`', city, 'id', city_id, 'population', city_pop)", - "named_struct(' ', city, ' ', city_id, ' ', city_pop)"), - ("cast(10 as decimal(10)", - "cast(10 as decimal(10)"), - (" ", " "), - ("", ""), - ]) + @pytest.mark.parametrize( + "sqlExpr, expectedText", + [ + ( + "named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + "named_struct(' ', city_name, ' ', city_id, ' ', city_pop)", + ), + ( + "named_struct('name', `city 2`, 'id', city_id, 'population', city_pop)", + "named_struct(' ', `city 2`, ' ', city_id, ' ', city_pop)", + ), + ( + "named_struct('`name 1`', `city 2`, 'id', city_id, 'population', city_pop)", + "named_struct(' ', `city 2`, ' ', city_id, ' ', city_pop)", + ), + ( + "named_struct('`name 1`', city, 'id', city_id, 'population', city_pop)", + "named_struct(' ', city, ' ', city_id, ' ', city_pop)", + ), + ("cast(10 as decimal(10)", "cast(10 as decimal(10)"), + (" ", " "), + ("", ""), + ], + ) def test_sql_expression_cleanser(self, sqlExpr, expectedText): newSql = dg.SchemaParser._cleanseSQL(sqlExpr) print(newSql) @@ -142,19 +190,26 @@ def test_sql_expression_cleanser(self, sqlExpr, expectedText): assert newSql == expectedText - @pytest.mark.parametrize("sqlExpr, expectedReferences, filterColumns", - [("named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - ['named_struct', 'city_name', 'city_id', 'city_pop'], - None), - ("named_struct('name', city_name, 'id', city_id, 'population', city_pop)", - ['city_name', 'city_pop'], - ['city_name', 'city_pop']), - ("cast(10 as decimal(10)", ['cast', 'as', 'decimal'], None), - ("cast(x as decimal(10)", ['x'], ['x']), - ("cast(`city 2` as decimal(10)", ['cast', 'city 2', 'as', 'decimal'], None), - (" ", [], None), - ("", [], None), - ]) + @pytest.mark.parametrize( + "sqlExpr, expectedReferences, filterColumns", + [ + ( + "named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + ['named_struct', 'city_name', 'city_id', 'city_pop'], + None, + ), + ( + "named_struct('name', city_name, 'id', city_id, 'population', city_pop)", + ['city_name', 'city_pop'], + ['city_name', 'city_pop'], + ), + ("cast(10 as decimal(10)", ['cast', 'as', 'decimal'], None), + ("cast(x as decimal(10)", ['x'], ['x']), + ("cast(`city 2` as decimal(10)", ['cast', 'city 2', 'as', 'decimal'], None), + (" ", [], None), + ("", [], None), + ], + ) def test_sql_expression_parser(self, sqlExpr, expectedReferences, filterColumns): references = dg.SchemaParser.columnsReferencesFromSQLString(sqlExpr, filterItems=filterColumns) assert references is not None diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 7ca18c20..6691d1af 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -22,25 +22,26 @@ def checkSchemaEquality(self, schema1, schema2): for c1, c2 in zip(schema1.fields, schema2.fields): self.assertEqual(c1.name, c2.name, msg=f"{c1.name} != {c2.name}") - self.assertEqual(c1.dataType, c2.dataType, - msg=f"{c1.name}.datatype ({c1.dataType}) != {c2.name}.datatype ({c2.dataType})") + self.assertEqual( + c1.dataType, + c2.dataType, + 
msg=f"{c1.name}.datatype ({c1.dataType}) != {c2.name}.datatype ({c2.dataType})", + ) def test_generate_table_script(self): tbl_name = "scripted_table1" spark.sql(f"drop table if exists {tbl_name}") - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=4) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=self.column_count) - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, partitions=4) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=self.column_count) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) creation_script = testDataSpec.scriptTable(name=tbl_name, tableFormat="parquet") @@ -66,18 +67,16 @@ def test_generate_table_script2(self): tbl_name = "scripted_table1" spark.sql(f"drop table if exists {tbl_name}") - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=4) - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=self.column_count) - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=4) + .withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=self.column_count) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) creation_script = testDataSpec.scriptTable(name=tbl_name, tableFormat="parquet", location="/tmp/test") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index bb1e02cc..46f010bd 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,50 +10,126 @@ class TestSerialization: - @pytest.mark.parametrize("expectation, columns", [ - (does_not_raise(), [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True} - ]), - (does_not_raise(), [ - {"colName": "col1", "colType": "int", "minValue": 0, 
"maxValue": 100, "step": 2, "random": True}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True} - ]), - (does_not_raise(), [ - {"colName": "col1", "colType": "string", "template": r"\w.\w@\w.com|\w@\w.co.u\k"}, - {"colName": "col2", "colType": "string", "text": {"kind": "TemplateGenerator", "template": "ddd-ddd-dddd"}}, - {"colName": "col3", "colType": "string", "text": {"kind": "TemplateGenerator", - "template": r"\w \w|\w \w \w|\w \a. \w", - "escapeSpecialChars": True, - "extendedWordList": ["red", "blue", "yellow"]}}, - {"colName": "col4", "colType": "string", "text": {"kind": "ILText", "paragraphs": 2, - "sentences": 4, "words": 10}} - ]), - (does_not_raise(), [ - {"colName": "col1", "colType": "date", "dataRange": {"kind": "DateRange", "begin": "2025-01-01 00:00:00", - "end": "2025-12-31 00:00:00", "interval": "days=1"}}, - {"colName": "col2", "colType": "double", "dataRange": {"kind": "NRange", "minValue": 0.0, - "maxValue": 10.0, "step": 0.1}} - ]), - (does_not_raise(), [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "random": True, - "distribution": {"kind": "Gamma", "shape": 1.0, "scale": 2.0}}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 1.0, "random": True, - "distribution": {"kind": "Beta", "alpha": 2, "beta": 5}}, - {"colName": "col3", "colType": "float", "minValue": 0, "maxValue": 10000, "random": True, - "distribution": {"kind": "Exponential", "rate": 1.5}}, - {"colName": "col4", "colType": "int", "minValue": 0, "maxValue": 100, "random": True, - "distribution": {"kind": "Normal", "mean": 50.0, "stddev": 2.0}}, - ]), - (pytest.raises(NotImplementedError), [ # Testing serialization error with PyfuncText - {"colName": "col1", "colType": "string", "text": {"kind": "PyfuncText", "fn": "lambda x: x.trim()"}} - ]), - (pytest.raises(ValueError), [ # Testing serialization error with a bad "kind" - {"colName": "col1", "colType": "string", "text": {"kind": "InvalidTextFactory", "property": "value"}} - ]) - ]) + @pytest.mark.parametrize( + "expectation, columns", + [ + ( + does_not_raise(), + [ + {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + ), + ( + does_not_raise(), + [ + {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": True}, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + ), + ( + does_not_raise(), + [ + {"colName": "col1", "colType": "string", "template": r"\w.\w@\w.com|\w@\w.co.u\k"}, + { + "colName": "col2", + "colType": "string", + "text": {"kind": "TemplateGenerator", "template": "ddd-ddd-dddd"}, + }, + { + "colName": "col3", + "colType": "string", + "text": { + "kind": "TemplateGenerator", + "template": r"\w \w|\w \w \w|\w \a. 
\w", + "escapeSpecialChars": True, + "extendedWordList": ["red", "blue", "yellow"], + }, + }, + { + "colName": "col4", + "colType": "string", + "text": {"kind": "ILText", "paragraphs": 2, "sentences": 4, "words": 10}, + }, + ], + ), + ( + does_not_raise(), + [ + { + "colName": "col1", + "colType": "date", + "dataRange": { + "kind": "DateRange", + "begin": "2025-01-01 00:00:00", + "end": "2025-12-31 00:00:00", + "interval": "days=1", + }, + }, + { + "colName": "col2", + "colType": "double", + "dataRange": {"kind": "NRange", "minValue": 0.0, "maxValue": 10.0, "step": 0.1}, + }, + ], + ), + ( + does_not_raise(), + [ + { + "colName": "col1", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "random": True, + "distribution": {"kind": "Gamma", "shape": 1.0, "scale": 2.0}, + }, + { + "colName": "col2", + "colType": "float", + "minValue": 0.0, + "maxValue": 1.0, + "random": True, + "distribution": {"kind": "Beta", "alpha": 2, "beta": 5}, + }, + { + "colName": "col3", + "colType": "float", + "minValue": 0, + "maxValue": 10000, + "random": True, + "distribution": {"kind": "Exponential", "rate": 1.5}, + }, + { + "colName": "col4", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "random": True, + "distribution": {"kind": "Normal", "mean": 50.0, "stddev": 2.0}, + }, + ], + ), + ( + pytest.raises(NotImplementedError), + [ # Testing serialization error with PyfuncText + {"colName": "col1", "colType": "string", "text": {"kind": "PyfuncText", "fn": "lambda x: x.trim()"}} + ], + ), + ( + pytest.raises(ValueError), + [ # Testing serialization error with a bad "kind" + { + "colName": "col1", + "colType": "string", + "text": {"kind": "InvalidTextFactory", "property": "value"}, + } + ], + ), + ], + ) def test_column_definitions_from_dict(self, columns, expectation): with expectation: # Test the options set on the ColumnGenerationSpecs: @@ -85,108 +161,208 @@ def test_column_definitions_from_dict(self, columns, expectation): df_from_dicts = gen_from_dicts.build() assert df_from_dicts.columns == [column["colName"] for column in columns] - @pytest.mark.parametrize("expectation, constraints", [ - (does_not_raise(), [ - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True} - ]), - (does_not_raise(), [ - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": False}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, - {"kind": "SqlExpr", "expr": "col1 > 0"}, - {"kind": "LiteralRelation", "columns": ["col2"], "relation": "<>", "value": "0"} - ]), - (pytest.raises(ValueError), [ # Testing an invalid "relation" value - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, - {"kind": "SqlExpr", "expr": "col1 > 0"}, - {"kind": "LiteralRelation", "columns": ["col2"], "relation": "+", "value": "0"} - ]), - (pytest.raises(TypeError), [ # Testing an invalid "kind" value - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": False}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, - {"kind": "SqlExpr", "expr": "col1 > 0"}, - {"kind": "InvalidConstraintType", "columns": ["col2"], "value": "0"} - ]), - (does_not_raise(), [ - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, - {"kind": 
"NegativeValues", "columns": ["col1", "col2"], "strict": False}, - {"kind": "ChainedRelation", "columns": ["col1", "col2"], "relation": ">"}, - {"kind": "RangedValues", "columns": ["col2"], "lowValue": 0, "highValue": 100, "strict": True}, - {"kind": "UniqueCombinations", "columns": ["col1", "col2"]} - ]), - ]) + @pytest.mark.parametrize( + "expectation, constraints", + [ + ( + does_not_raise(), + [ + {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, + {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, + ], + ), + ( + does_not_raise(), + [ + { + "kind": "LiteralRange", + "columns": ["col1"], + "lowValue": -1000, + "highValue": 1000, + "strict": False, + }, + {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, + {"kind": "SqlExpr", "expr": "col1 > 0"}, + {"kind": "LiteralRelation", "columns": ["col2"], "relation": "<>", "value": "0"}, + ], + ), + ( + pytest.raises(ValueError), + [ # Testing an invalid "relation" value + {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, + {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, + {"kind": "SqlExpr", "expr": "col1 > 0"}, + {"kind": "LiteralRelation", "columns": ["col2"], "relation": "+", "value": "0"}, + ], + ), + ( + pytest.raises(TypeError), + [ # Testing an invalid "kind" value + { + "kind": "LiteralRange", + "columns": ["col1"], + "lowValue": -1000, + "highValue": 1000, + "strict": False, + }, + {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, + {"kind": "SqlExpr", "expr": "col1 > 0"}, + {"kind": "InvalidConstraintType", "columns": ["col2"], "value": "0"}, + ], + ), + ( + does_not_raise(), + [ + {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, + {"kind": "NegativeValues", "columns": ["col1", "col2"], "strict": False}, + {"kind": "ChainedRelation", "columns": ["col1", "col2"], "relation": ">"}, + {"kind": "RangedValues", "columns": ["col2"], "lowValue": 0, "highValue": 100, "strict": True}, + {"kind": "UniqueCombinations", "columns": ["col1", "col2"]}, + ], + ), + ], + ) def test_constraint_definitions_from_dict(self, constraints, expectation): with expectation: # Test the options set on the ColumnGenerationSpecs: columns = [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True} + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, ] - gen_from_dicts = dg.DataGenerator(rows=100, partitions=1) \ - ._loadColumnsFromInitializationDicts(columns) \ + gen_from_dicts = ( + dg.DataGenerator(rows=100, partitions=1) + ._loadColumnsFromInitializationDicts(columns) ._loadConstraintsFromInitializationDicts(constraints) + ) constraint_specs = [constraint._toInitializationDict() for constraint in gen_from_dicts.constraints] for constraint in constraints: assert constraint in constraint_specs - @pytest.mark.parametrize("expectation, options", [ - (does_not_raise(), - {"name": "test_generator", "rows": 1000, - "columns": [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}] - }), - (does_not_raise(), - {"name": "test_generator", 
"rows": 10000, "randomSeed": 42, - "columns": [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": True}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}] - }), - (does_not_raise(), - {"name": "test_generator", "rows": 10000, "randomSeed": 42, - "columns": [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": True}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}], - "constraints": [ - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, - {"kind": "SqlExpr", "expr": "col1 > 0"}, - {"kind": "LiteralRelation", "columns": ["col2"], "relation": "<>", "value": "0"}] - }), - (does_not_raise(), # Testing a dictionary missing a "generator" object - {"columns": [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": True}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}] - }), - (pytest.raises(TypeError), # Testing an invalid "kind" value - {"name": "test_generator", "rows": 10000, "randomSeed": 42, - "columns": [ - {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": True}, - {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, - {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}], - "constraints": [ - {"kind": "LiteralRange", "columns": ["col1"], "lowValue": -1000, "highValue": 1000, "strict": True}, - {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, - {"kind": "SqlExpr", "expr": "col1 > 0"}, - {"kind": "InvalidConstraintType", "columns": ["col2"], "value": 0}] - }), - ]) + @pytest.mark.parametrize( + "expectation, options", + [ + ( + does_not_raise(), + { + "name": "test_generator", + "rows": 1000, + "columns": [ + {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + }, + ), + ( + does_not_raise(), + { + "name": "test_generator", + "rows": 10000, + "randomSeed": 42, + "columns": [ + { + "colName": "col1", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "step": 2, + "random": True, + }, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + }, + ), + ( + does_not_raise(), + { + "name": "test_generator", + "rows": 10000, + "randomSeed": 42, + "columns": [ + { + "colName": "col1", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "step": 2, + "random": True, + }, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + "constraints": [ + { + "kind": "LiteralRange", + "columns": ["col1"], + "lowValue": -1000, + "highValue": 1000, + "strict": True, + }, + {"kind": "PositiveValues", 
"columns": ["col1", "col2"], "strict": True}, + {"kind": "SqlExpr", "expr": "col1 > 0"}, + {"kind": "LiteralRelation", "columns": ["col2"], "relation": "<>", "value": "0"}, + ], + }, + ), + ( + does_not_raise(), # Testing a dictionary missing a "generator" object + { + "columns": [ + { + "colName": "col1", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "step": 2, + "random": True, + }, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ] + }, + ), + ( + pytest.raises(TypeError), # Testing an invalid "kind" value + { + "name": "test_generator", + "rows": 10000, + "randomSeed": 42, + "columns": [ + { + "colName": "col1", + "colType": "int", + "minValue": 0, + "maxValue": 100, + "step": 2, + "random": True, + }, + {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, + {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": True}, + ], + "constraints": [ + { + "kind": "LiteralRange", + "columns": ["col1"], + "lowValue": -1000, + "highValue": 1000, + "strict": True, + }, + {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": True}, + {"kind": "SqlExpr", "expr": "col1 > 0"}, + {"kind": "InvalidConstraintType", "columns": ["col2"], "value": 0}, + ], + }, + ), + ], + ) def test_generator_from_dict(self, options, expectation): with expectation: # Test the options set on the DataGenerator: gen_from_dicts = dg.DataGenerator.loadFromInitializationDict(options) - generator = { - k: v for k, v in options.items() - if not isinstance(v, list) - and not isinstance(v, dict) - } + generator = {k: v for k, v in options.items() if not isinstance(v, list) and not isinstance(v, dict)} for key in generator: assert gen_from_dicts.saveToInitializationDict()[key] == generator[key] @@ -199,10 +375,7 @@ def test_generator_from_dict(self, options, expectation): # Test the options set on the Constraints: constraints = options.get("constraints", []) - constraint_specs = [ - constraint._toInitializationDict() - for constraint in gen_from_dicts.constraints - ] + constraint_specs = [constraint._toInitializationDict() for constraint in gen_from_dicts.constraints] for constraint in constraints: assert constraint in constraint_specs @@ -210,23 +383,30 @@ def test_generator_from_dict(self, options, expectation): df_from_dicts = gen_from_dicts.build() assert df_from_dicts.columns == ["col1", "col2", "col3"] - @pytest.mark.parametrize("expectation, json_options", [ - (does_not_raise(), - '''{"name": "test_generator", "rows": 1000, + @pytest.mark.parametrize( + "expectation, json_options", + [ + ( + does_not_raise(), + '''{"name": "test_generator", "rows": 1000, "columns": [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0}, {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": true}] - }'''), - (does_not_raise(), - '''{"name": "test_generator", "rows": 10000, "randomSeed": 42, + }''', + ), + ( + does_not_raise(), + '''{"name": "test_generator", "rows": 10000, "randomSeed": 42, "columns": [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": true}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": true}] - }'''), - (does_not_raise(), - '''{"name": 
"test_generator", "rows": 10000, "randomSeed": 42, + }''', + ), + ( + does_not_raise(), + '''{"name": "test_generator", "rows": 10000, "randomSeed": 42, "columns": [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": true}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, @@ -236,15 +416,19 @@ def test_generator_from_dict(self, options, expectation): {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": true}, {"kind": "SqlExpr", "expr": "col1 > 0"}, {"kind": "LiteralRelation", "columns": ["col2"], "relation": "<>", "value": 0}] - }'''), - (does_not_raise(), # Testing a JSON object missing the "generator" key - '''{"columns": [ + }''', + ), + ( + does_not_raise(), # Testing a JSON object missing the "generator" key + '''{"columns": [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": true}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, {"colName": "col3", "colType": "string", "values": ["a", "b", "c"], "random": true}] - }'''), - (pytest.raises(TypeError), # Testing an invalid "kind" value - '''{"name": "test_generator", "rows": 10000, "randomSeed": 42, + }''', + ), + ( + pytest.raises(TypeError), # Testing an invalid "kind" value + '''{"name": "test_generator", "rows": 10000, "randomSeed": 42, "columns": [ {"colName": "col1", "colType": "int", "minValue": 0, "maxValue": 100, "step": 2, "random": true}, {"colName": "col2", "colType": "float", "minValue": 0.0, "maxValue": 100.0, "step": 1.5}, @@ -254,18 +438,16 @@ def test_generator_from_dict(self, options, expectation): {"kind": "PositiveValues", "columns": ["col1", "col2"], "strict": true}, {"kind": "SqlExpr", "expr": "col1 > 0"}, {"kind": "InvalidConstraintType", "columns": ["col2"], "value": 0}] - }'''), - ]) + }''', + ), + ], + ) def test_generator_from_json(self, json_options, expectation): options = json.loads(json_options) with expectation: # Test the options set on the DataGenerator: gen_from_dicts = dg.DataGenerator.loadFromJson(json_options) - generator = { - k: v for k, v in options.items() - if not isinstance(v, list) - and not isinstance(v, dict) - } + generator = {k: v for k, v in options.items() if not isinstance(v, list) and not isinstance(v, dict)} for key in generator: assert gen_from_dicts.saveToInitializationDict()[key] == generator[key] @@ -278,10 +460,7 @@ def test_generator_from_json(self, json_options, expectation): # Test the options set on the Constraints: constraints = options.get("constraints", []) - constraint_specs = [ - constraint._toInitializationDict() - for constraint in gen_from_dicts.constraints - ] + constraint_specs = [constraint._toInitializationDict() for constraint in gen_from_dicts.constraints] for constraint in constraints: assert constraint in constraint_specs @@ -309,7 +488,7 @@ def test_from_options(self): "kind": "ColumnGenerationSpec", "name": "col1", "colType": "double", - "dataRange": {"kind": "NRange", "minValue": 0.0, "maxValue": 100.0, "step": 0.1} + "dataRange": {"kind": "NRange", "minValue": 0.0, "maxValue": 100.0, "step": 0.1}, } column = ColumnGenerationSpec._fromInitializationDict(options) column_dict = column._toInitializationDict() diff --git a/tests/test_serverless.py b/tests/test_serverless.py index a20085cc..f0f48798 100644 --- a/tests/test_serverless.py +++ b/tests/test_serverless.py @@ -24,9 +24,10 @@ def serverlessSpark(self): oldSetMethod = sparkSession.conf.set oldGetMethod = sparkSession.conf.get + 
def mock_conf_set(*args, **kwargs): raise ValueError("Setting value prohibited in simulated serverless env.") - + def mock_conf_get(config_key, default=None): # Allow internal PySpark configuration calls that are needed for basic operation whitelisted_configs = { @@ -34,7 +35,7 @@ def mock_conf_get(config_key, default=None): 'spark.sql.execution.arrow.enabled': 'false', 'spark.sql.execution.arrow.pyspark.enabled': 'false', 'spark.python.sql.dataFrameDebugging.enabled': 'true', - 'spark.sql.execution.arrow.maxRecordsPerBatch': '10000' + 'spark.sql.execution.arrow.maxRecordsPerBatch': '10000', } if config_key in whitelisted_configs: try: @@ -43,7 +44,7 @@ def mock_conf_get(config_key, default=None): return whitelisted_configs[config_key] else: raise ValueError("Getting value prohibited in simulated serverless env.") - + sparkSession.conf.set = MagicMock(side_effect=mock_conf_set) sparkSession.conf.get = MagicMock(side_effect=mock_conf_get) @@ -69,23 +70,24 @@ def test_basic_data(self, serverlessSpark): .withColumn("code1", IntegerType(), minValue=100, maxValue=200) .withColumn("code2", "integer", minValue=0, maxValue=10, random=True) .withColumn("code3", StringType(), values=["online", "offline", "unknown"]) - .withColumn( - "code4", StringType(), values=["a", "b", "c"], random=True, percentNulls=0.05 - ) - .withColumn( - "code5", "string", values=["a", "b", "c"], random=True, weights=[9, 1, 1] - ) + .withColumn("code4", StringType(), values=["a", "b", "c"], random=True, percentNulls=0.05) + .withColumn("code5", "string", values=["a", "b", "c"], random=True, weights=[9, 1, 1]) ) testDataSpec.build() - @pytest.mark.parametrize("providerName, providerOptions", [ - ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}), - ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0}) - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}), + ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0}), + ], + ) def test_basic_user_table_retrieval(self, providerName, providerOptions, serverlessSpark): ds = dg.Datasets(serverlessSpark, providerName).get(**providerOptions) - assert ds is not None, f"""expected to get dataset specification for provider `{providerName}` + assert ( + ds is not None + ), f"""expected to get dataset specification for provider `{providerName}` with options: {providerOptions} """ df = ds.build() diff --git a/tests/test_shared_env.py b/tests/test_shared_env.py index 8efca1c8..04027dbf 100644 --- a/tests/test_shared_env.py +++ b/tests/test_shared_env.py @@ -19,6 +19,7 @@ class TestSharedEnv: """ + SMALL_ROW_COUNT = 100000 COLUMN_COUNT = 10 diff --git a/tests/test_standard_dataset_providers.py b/tests/test_standard_dataset_providers.py index 1123bf8c..8e23bb19 100644 --- a/tests/test_standard_dataset_providers.py +++ b/tests/test_standard_dataset_providers.py @@ -7,38 +7,87 @@ class TestStandardDatasetProviders: - + # BASIC GEOMETRIES tests: - @pytest.mark.parametrize("providerName, providerOptions, expectation", [ - ("basic/geometries", {}, does_not_raise()), - ("basic/geometries", {"rows": 50, "partitions": 4, "random": False, - "geometryType": "point", "maxVertices": 1}, does_not_raise()), - ("basic/geometries", {"rows": 100, "partitions": -1, "random": False, - "geometryType": "point", "maxVertices": 2}, does_not_raise()), - ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, - "geometryType": 
"point"}, does_not_raise()), - ("basic/geometries", {"rows": 5000, "partitions": -1, "random": True, - "geometryType": "lineString"}, does_not_raise()), - ("basic/geometries", {"rows": -1, "partitions": -1, "random": False, - "geometryType": "lineString", "maxVertices": 2}, does_not_raise()), - ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, - "geometryType": "lineString", "maxVertices": 1}, does_not_raise()), - ("basic/geometries", {"rows": 5000, "partitions": 4, - "geometryType": "lineString", "maxVertices": 2}, does_not_raise()), - ("basic/geometries", {"rows": 5000, "partitions": -1, "random": False, - "geometryType": "polygon"}, does_not_raise()), - ("basic/geometries", {"rows": -1, "partitions": -1, "random": True, - "geometryType": "polygon", "maxVertices": 3}, does_not_raise()), - ("basic/geometries", {"rows": -1, "partitions": 4, "random": True, - "geometryType": "polygon", "maxVertices": 2}, does_not_raise()), - ("basic/geometries", {"rows": 5000, "partitions": 4, - "geometryType": "polygon", "maxVertices": 5}, does_not_raise()), - ("basic/geometries", - {"rows": 5000, "partitions": 4, "geometryType": "polygon", "minLatitude": 45.0, - "maxLatitude": 50.0, "minLongitude": -85.0, "maxLongitude": -75.0}, does_not_raise()), - ("basic/geometries", - {"rows": -1, "partitions": -1, "geometryType": "multipolygon"}, pytest.raises(ValueError)) - ]) + @pytest.mark.parametrize( + "providerName, providerOptions, expectation", + [ + ("basic/geometries", {}, does_not_raise()), + ( + "basic/geometries", + {"rows": 50, "partitions": 4, "random": False, "geometryType": "point", "maxVertices": 1}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": 100, "partitions": -1, "random": False, "geometryType": "point", "maxVertices": 2}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": 4, "random": True, "geometryType": "point"}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": 5000, "partitions": -1, "random": True, "geometryType": "lineString"}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": -1, "random": False, "geometryType": "lineString", "maxVertices": 2}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": 4, "random": True, "geometryType": "lineString", "maxVertices": 1}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": 5000, "partitions": 4, "geometryType": "lineString", "maxVertices": 2}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": 5000, "partitions": -1, "random": False, "geometryType": "polygon"}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": -1, "random": True, "geometryType": "polygon", "maxVertices": 3}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": 4, "random": True, "geometryType": "polygon", "maxVertices": 2}, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": 5000, "partitions": 4, "geometryType": "polygon", "maxVertices": 5}, + does_not_raise(), + ), + ( + "basic/geometries", + { + "rows": 5000, + "partitions": 4, + "geometryType": "polygon", + "minLatitude": 45.0, + "maxLatitude": 50.0, + "minLongitude": -85.0, + "maxLongitude": -75.0, + }, + does_not_raise(), + ), + ( + "basic/geometries", + {"rows": -1, "partitions": -1, "geometryType": "multipolygon"}, + pytest.raises(ValueError), + ), + ], + ) def test_basic_geometries_retrieval(self, providerName, providerOptions, expectation): with expectation: ds = dg.Datasets(spark, 
providerName).get(**providerOptions) @@ -67,44 +116,126 @@ def test_basic_geometries_retrieval(self, providerName, providerOptions, expecta assert ids != sorted(ids) # BASIC PROCESS HISTORIAN tests: - @pytest.mark.parametrize("providerName, providerOptions", [ - ("basic/process_historian", - {"rows": 50, "partitions": 4, "random": False, "numDevices": 1, "numPlants": 1, - "numTags": 1, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00"}), - ("basic/process_historian", - {"rows": 1000, "partitions": -1, "random": True, "numDevices": 10, "numPlants": 2, - "numTags": 2, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00"}), - ("basic/process_historian", - {"rows": 5000, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00"}), - ("basic/process_historian", {}), - ("basic/process_historian", - {"rows": 5000, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-04-01 00:00:00", "endTimestamp": "2020-01-01 00:00:00"}), - ("basic/process_historian", - {"rows": 100, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00"}), - ("basic/process_historian", - {"rows": 100, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00", - "dataQualityRatios": {"pctQuestionable": 0.1, "pctAnnotated": 0.05, "pctSubstituded": 0.12}}), - ("basic/process_historian", - {"rows": 100, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00", - "dataQualityRatios": {"pctQuestionable": 0.1, "pctSubstituded": 0.12}}), - ("basic/process_historian", - {"rows": 100, "partitions": -1, "random": True, "numDevices": 100, "numPlants": 10, - "numTags": 5, "startTimestamp": "2020-01-01 00:00:00", "endTimestamp": "2020-04-01 00:00:00", - "dataQualityRatios": {"pctAnnotated": 0.05}}), - - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ( + "basic/process_historian", + { + "rows": 50, + "partitions": 4, + "random": False, + "numDevices": 1, + "numPlants": 1, + "numTags": 1, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + }, + ), + ( + "basic/process_historian", + { + "rows": 1000, + "partitions": -1, + "random": True, + "numDevices": 10, + "numPlants": 2, + "numTags": 2, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + }, + ), + ( + "basic/process_historian", + { + "rows": 5000, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + }, + ), + ("basic/process_historian", {}), + ( + "basic/process_historian", + { + "rows": 5000, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-04-01 00:00:00", + "endTimestamp": "2020-01-01 00:00:00", + }, + ), + ( + "basic/process_historian", + { + "rows": 100, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + }, + ), + ( + 
"basic/process_historian", + { + "rows": 100, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "dataQualityRatios": {"pctQuestionable": 0.1, "pctAnnotated": 0.05, "pctSubstituded": 0.12}, + }, + ), + ( + "basic/process_historian", + { + "rows": 100, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "dataQualityRatios": {"pctQuestionable": 0.1, "pctSubstituded": 0.12}, + }, + ), + ( + "basic/process_historian", + { + "rows": 100, + "partitions": -1, + "random": True, + "numDevices": 100, + "numPlants": 10, + "numTags": 5, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "dataQualityRatios": {"pctAnnotated": 0.05}, + }, + ), + ], + ) def test_basic_process_historian_retrieval(self, providerName, providerOptions): ds = dg.Datasets(spark, providerName).get(**providerOptions) assert ds is not None df = ds.build() assert df.count() >= 0 - + startTimestamp = providerOptions.get("startTimestamp", "2024-01-01 00:00:00") endTimestamp = providerOptions.get("endTimestamp", "2024-02-01 00:00:00") if startTimestamp > endTimestamp: @@ -120,37 +251,62 @@ def test_basic_process_historian_retrieval(self, providerName, providerOptions): assert ids != sorted(ids) # BASIC STOCK TICKER tests: - @pytest.mark.parametrize("providerName, providerOptions, expectation", [ - ("basic/stock_ticker", - {"rows": 50, "partitions": 4, "numSymbols": 5, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 100, "partitions": -1, "numSymbols": 5, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": -1, "partitions": 4, "numSymbols": 10, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", {}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 5000, "partitions": -1, "numSymbols": 50, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 5000, "partitions": 4, "numSymbols": 50}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 5000, "partitions": 4, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 5000, "partitions": 4, "numSymbols": 100, "startDate": "2024-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 1000, "partitions": -1, "numSymbols": 100, "startDate": "2025-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 1000, "partitions": -1, "numSymbols": 10, "startDate": "2020-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 50, "partitions": 2, "numSymbols": 0, "startDate": "2020-06-04"}, pytest.raises(ValueError)), - ("basic/stock_ticker", - {"rows": 500, "numSymbols": 12, "startDate": "2025-06-04"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 10, "partitions": 1, "numSymbols": -1, "startDate": "2009-01-02"}, pytest.raises(ValueError)), - ("basic/stock_ticker", - {"partitions": 2, "numSymbols": 20, "startDate": "2021-01-01"}, does_not_raise()), - ("basic/stock_ticker", - {"rows": 50, "partitions": 2, "numSymbols": 2}, does_not_raise()), - ]) + @pytest.mark.parametrize( + "providerName, providerOptions, expectation", + [ + ( + "basic/stock_ticker", + {"rows": 50, "partitions": 4, "numSymbols": 5, "startDate": "2024-01-01"}, + does_not_raise(), + ), + ( + "basic/stock_ticker", + {"rows": 100, "partitions": -1, "numSymbols": 5, 
"startDate": "2024-01-01"}, + does_not_raise(), + ), + ( + "basic/stock_ticker", + {"rows": -1, "partitions": 4, "numSymbols": 10, "startDate": "2024-01-01"}, + does_not_raise(), + ), + ("basic/stock_ticker", {}, does_not_raise()), + ( + "basic/stock_ticker", + {"rows": 5000, "partitions": -1, "numSymbols": 50, "startDate": "2024-01-01"}, + does_not_raise(), + ), + ("basic/stock_ticker", {"rows": 5000, "partitions": 4, "numSymbols": 50}, does_not_raise()), + ("basic/stock_ticker", {"rows": 5000, "partitions": 4, "startDate": "2024-01-01"}, does_not_raise()), + ( + "basic/stock_ticker", + {"rows": 5000, "partitions": 4, "numSymbols": 100, "startDate": "2024-01-01"}, + does_not_raise(), + ), + ( + "basic/stock_ticker", + {"rows": 1000, "partitions": -1, "numSymbols": 100, "startDate": "2025-01-01"}, + does_not_raise(), + ), + ( + "basic/stock_ticker", + {"rows": 1000, "partitions": -1, "numSymbols": 10, "startDate": "2020-01-01"}, + does_not_raise(), + ), + ( + "basic/stock_ticker", + {"rows": 50, "partitions": 2, "numSymbols": 0, "startDate": "2020-06-04"}, + pytest.raises(ValueError), + ), + ("basic/stock_ticker", {"rows": 500, "numSymbols": 12, "startDate": "2025-06-04"}, does_not_raise()), + ( + "basic/stock_ticker", + {"rows": 10, "partitions": 1, "numSymbols": -1, "startDate": "2009-01-02"}, + pytest.raises(ValueError), + ), + ("basic/stock_ticker", {"partitions": 2, "numSymbols": 20, "startDate": "2021-01-01"}, does_not_raise()), + ("basic/stock_ticker", {"rows": 50, "partitions": 2, "numSymbols": 2}, does_not_raise()), + ], + ) def test_basic_stock_ticker_retrieval(self, providerName, providerOptions, expectation): with expectation: ds = dg.Datasets(spark, providerName).get(**providerOptions) @@ -162,58 +318,180 @@ def test_basic_stock_ticker_retrieval(self, providerName, providerOptions, expec assert df.selectExpr("symbol").distinct().count() == providerOptions.get("numSymbols") if "startDate" in providerOptions: - assert df.selectExpr("min(post_date) as min_post_date") \ - .collect()[0] \ - .asDict()["min_post_date"] == date.fromisoformat(providerOptions.get("startDate")) + assert df.selectExpr("min(post_date) as min_post_date").collect()[0].asDict()[ + "min_post_date" + ] == date.fromisoformat(providerOptions.get("startDate")) - assert df.where("""open < 0.0 + assert ( + df.where( + """open < 0.0 or close < 0.0 or high < 0.0 or low < 0.0 - or adj_close < 0.0""").count() == 0 + or adj_close < 0.0""" + ).count() + == 0 + ) assert df.where("high < low").count() == 0 # BASIC TELEMATICS tests: - @pytest.mark.parametrize("providerName, providerOptions", [ - ("basic/telematics", - {"rows": 50, "partitions": 4, "random": False, "numDevices": 5000, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLat": 40.0, "maxLat": 43.0, "minLon": -93.0, "maxLon": -89.0, - "generateWkt": False}), - ("basic/telematics", - {"rows": 1000, "partitions": 4, "random": True, "numDevices": 1000, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLat": 45.0, "maxLat": 35.0, "minLon": -89.0, "maxLon": -93.0, - "generateWkt": True}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "numDevices": 1000, "minLat": 98.0, "maxLat": 100.0, - "minLon": -181.0, "maxLon": -185.0, "generateWkt": False}), - ("basic/telematics", - {"rows": 5000, "partitions": -1, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "generateWkt": True}), - ("basic/telematics", {}), - ("basic/telematics", - {"rows": -1, "partitions": -1, 
"random": False, "numDevices": 50, "startTimestamp": "2020-06-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLat": 40.0, "maxLat": 43.0, "minLon": -93.0, "maxLon": -89.0, - "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "maxLat": 45.0, "minLon": -93.0, "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLat": 45.0, "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLat": -120.0, "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "maxLat": -120.0, "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "minLon": 190.0, "generateWkt": False}), - ("basic/telematics", - {"rows": -1, "partitions": -1, "random": False, "numDevices": 100, "startTimestamp": "2020-01-01 00:00:00", - "endTimestamp": "2020-04-01 00:00:00", "maxLon": 190.0, "generateWkt": False}), - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ( + "basic/telematics", + { + "rows": 50, + "partitions": 4, + "random": False, + "numDevices": 5000, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "minLat": 40.0, + "maxLat": 43.0, + "minLon": -93.0, + "maxLon": -89.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": 1000, + "partitions": 4, + "random": True, + "numDevices": 1000, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "minLat": 45.0, + "maxLat": 35.0, + "minLon": -89.0, + "maxLon": -93.0, + "generateWkt": True, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "numDevices": 1000, + "minLat": 98.0, + "maxLat": 100.0, + "minLon": -181.0, + "maxLon": -185.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": 5000, + "partitions": -1, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "generateWkt": True, + }, + ), + ("basic/telematics", {}), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 50, + "startTimestamp": "2020-06-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "minLat": 40.0, + "maxLat": 43.0, + "minLon": -93.0, + "maxLon": -89.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "maxLat": 45.0, + "minLon": -93.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "minLat": 45.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": 
"2020-04-01 00:00:00", + "minLat": -120.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "maxLat": -120.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "minLon": 190.0, + "generateWkt": False, + }, + ), + ( + "basic/telematics", + { + "rows": -1, + "partitions": -1, + "random": False, + "numDevices": 100, + "startTimestamp": "2020-01-01 00:00:00", + "endTimestamp": "2020-04-01 00:00:00", + "maxLon": 190.0, + "generateWkt": False, + }, + ), + ], + ) def test_basic_telematics_retrieval(self, providerName, providerOptions): ds = dg.Datasets(spark, providerName).get(**providerOptions) assert ds is not None @@ -272,17 +550,22 @@ def test_basic_telematics_retrieval(self, providerName, providerOptions): assert ids != sorted(ids) # BASIC USER tests: - @pytest.mark.parametrize("providerName, providerOptions", [ - ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}), - ("basic/user", {"rows": -1, "partitions": 4, "random": False, "dummyValues": 0}), - ("basic/user", {}), - ("basic/user", {"rows": 100, "partitions": -1, "random": False, "dummyValues": 10}), - ("basic/user", {"rows": 5000, "dummyValues": 4}), - ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0}), - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}), + ("basic/user", {"rows": -1, "partitions": 4, "random": False, "dummyValues": 0}), + ("basic/user", {}), + ("basic/user", {"rows": 100, "partitions": -1, "random": False, "dummyValues": 10}), + ("basic/user", {"rows": 5000, "dummyValues": 4}), + ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0}), + ], + ) def test_basic_user_table_retrieval(self, providerName, providerOptions): ds = dg.Datasets(spark, providerName).get(**providerOptions) - assert ds is not None, f"""expected to get dataset specification for provider `{providerName}` + assert ( + ds is not None + ), f"""expected to get dataset specification for provider `{providerName}` with options: {providerOptions} """ df = ds.build() @@ -297,19 +580,28 @@ def test_basic_user_table_retrieval(self, providerName, providerOptions): assert customer_ids != sorted(customer_ids) # BENCHMARK GROUPBY tests: - @pytest.mark.parametrize("providerName, providerOptions", [ - ("benchmark/groupby", {"rows": 50, "partitions": 4, "random": False, "groups": 10, "percentNulls": 0.1}), - ("benchmark/groupby", {"rows": -1, "partitions": 4, "random": True, "groups": 100}), - ("benchmark/groupby", {}), - ("benchmark/groupby", {"rows": 1000, "partitions": -1, "random": False}), - ("benchmark/groupby", {"rows": -1, "groups": 1000, "percentNulls": 0.2}), - ("benchmark/groupby", {"rows": 1000, "partitions": -1, "random": True, "groups": 5000, "percentNulls": 0.5}), - ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": True, "groups": 0}), - ("benchmark/groupby", {"rows": 10, "partitions": -1, "random": True, "groups": 100, "percentNulls": 0.1}), - ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": False, "groups": -50}), - ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": False, "groups": 
-50, "percentNulls": -12.1}), - ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": True, "groups": -50, "percentNulls": 1.1}), - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ("benchmark/groupby", {"rows": 50, "partitions": 4, "random": False, "groups": 10, "percentNulls": 0.1}), + ("benchmark/groupby", {"rows": -1, "partitions": 4, "random": True, "groups": 100}), + ("benchmark/groupby", {}), + ("benchmark/groupby", {"rows": 1000, "partitions": -1, "random": False}), + ("benchmark/groupby", {"rows": -1, "groups": 1000, "percentNulls": 0.2}), + ( + "benchmark/groupby", + {"rows": 1000, "partitions": -1, "random": True, "groups": 5000, "percentNulls": 0.5}, + ), + ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": True, "groups": 0}), + ("benchmark/groupby", {"rows": 10, "partitions": -1, "random": True, "groups": 100, "percentNulls": 0.1}), + ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": False, "groups": -50}), + ( + "benchmark/groupby", + {"rows": -1, "partitions": -1, "random": False, "groups": -50, "percentNulls": -12.1}, + ), + ("benchmark/groupby", {"rows": -1, "partitions": -1, "random": True, "groups": -50, "percentNulls": 1.1}), + ], + ) def test_benchmark_groupby_retrieval(self, providerName, providerOptions): ds = dg.Datasets(spark, providerName).get(**providerOptions) assert ds is not None @@ -329,57 +621,101 @@ def test_benchmark_groupby_retrieval(self, providerName, providerOptions): assert vals != sorted(vals) # MULTI-TABLE SALES ORDER tests: - @pytest.mark.parametrize("providerName, providerOptions, expectation", [ - ("multi_table/sales_order", {"rows": 50, "partitions": 4}, does_not_raise()), - ("multi_table/sales_order", {"rows": -1, "partitions": 4}, does_not_raise()), - ("multi_table/sales_order", {}, does_not_raise()), - ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"rows": 5000, "dummyValues": 4}, does_not_raise()), - ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "customers", "numCustomers": 100}, does_not_raise()), - ("multi_table/sales_order", {"table": "customers", "numCustomers": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "customers", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "customers", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "carriers", "numCarriers": 50}, does_not_raise()), - ("multi_table/sales_order", {"table": "carriers", "numCarriers": -1, "dummyValues": 2}, does_not_raise()), - ("multi_table/sales_order", {"table": "carriers", "numCustomers": 100}, does_not_raise()), - ("multi_table/sales_order", {"table": "carriers", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "carriers", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "catalog_items", "numCatalogItems": -1, - "dummyValues": 5}, does_not_raise()), - ("multi_table/sales_order", {"table": "catalog_items", "numCatalogItems": 100, - "numCustomers": 1000}, does_not_raise()), - ("multi_table/sales_order", {"table": "catalog_items", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "catalog_items", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_orders", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_orders", "rows": -1, "partitions": -1}, 
does_not_raise()), - ("multi_table/sales_order", {"table": "base_orders", "numOrders": -1, "numCustomers": -1, "startDate": None, - "endDate": None}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_orders", "numOrders": 1000, "numCustomers": 10, - "dummyValues": 2}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_line_items", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_line_items", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_line_items", "numOrders": 1000, - "dummyValues": 5}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_line_items", "numOrders": -1, "numCatalogItems": -1, - "lineItemsPerOrder": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_shipments", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_shipments", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_shipments", "numOrders": 1000, - "numCarriers": 10}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_order_shipments", "numOrders": -1, "numCarriers": -1, - "dummyValues": 2}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_invoices", "rows": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_invoices", "rows": -1, "partitions": -1}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_invoices", "numOrders": 1000, - "numCustomers": 10}, does_not_raise()), - ("multi_table/sales_order", {"table": "base_invoices", "numOrders": -1, "dummyValues": 2}, does_not_raise()), - ("multi_table/sales_order", {"table": "invalid_table_name"}, pytest.raises(ValueError)) - ]) + @pytest.mark.parametrize( + "providerName, providerOptions, expectation", + [ + ("multi_table/sales_order", {"rows": 50, "partitions": 4}, does_not_raise()), + ("multi_table/sales_order", {"rows": -1, "partitions": 4}, does_not_raise()), + ("multi_table/sales_order", {}, does_not_raise()), + ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"rows": 5000, "dummyValues": 4}, does_not_raise()), + ("multi_table/sales_order", {"rows": 100, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "numCustomers": 100}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "numCustomers": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "customers", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCarriers": 50}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCarriers": -1, "dummyValues": 2}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "numCustomers": 100}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "carriers", "rows": -1, "partitions": -1}, does_not_raise()), + ( + "multi_table/sales_order", + {"table": "catalog_items", "numCatalogItems": -1, "dummyValues": 5}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "catalog_items", "numCatalogItems": 100, "numCustomers": 1000}, + does_not_raise(), + ), + ("multi_table/sales_order", {"table": "catalog_items", "rows": -1}, does_not_raise()), + 
("multi_table/sales_order", {"table": "catalog_items", "rows": -1, "partitions": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_orders", "rows": -1, "partitions": -1}, does_not_raise()), + ( + "multi_table/sales_order", + {"table": "base_orders", "numOrders": -1, "numCustomers": -1, "startDate": None, "endDate": None}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_orders", "numOrders": 1000, "numCustomers": 10, "dummyValues": 2}, + does_not_raise(), + ), + ("multi_table/sales_order", {"table": "base_order_line_items", "rows": -1}, does_not_raise()), + ( + "multi_table/sales_order", + {"table": "base_order_line_items", "rows": -1, "partitions": -1}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_order_line_items", "numOrders": 1000, "dummyValues": 5}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_order_line_items", "numOrders": -1, "numCatalogItems": -1, "lineItemsPerOrder": -1}, + does_not_raise(), + ), + ("multi_table/sales_order", {"table": "base_order_shipments", "rows": -1}, does_not_raise()), + ( + "multi_table/sales_order", + {"table": "base_order_shipments", "rows": -1, "partitions": -1}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_order_shipments", "numOrders": 1000, "numCarriers": 10}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_order_shipments", "numOrders": -1, "numCarriers": -1, "dummyValues": 2}, + does_not_raise(), + ), + ("multi_table/sales_order", {"table": "base_invoices", "rows": -1}, does_not_raise()), + ("multi_table/sales_order", {"table": "base_invoices", "rows": -1, "partitions": -1}, does_not_raise()), + ( + "multi_table/sales_order", + {"table": "base_invoices", "numOrders": 1000, "numCustomers": 10}, + does_not_raise(), + ), + ( + "multi_table/sales_order", + {"table": "base_invoices", "numOrders": -1, "dummyValues": 2}, + does_not_raise(), + ), + ("multi_table/sales_order", {"table": "invalid_table_name"}, pytest.raises(ValueError)), + ], + ) def test_multi_table_sales_order_retrieval(self, providerName, providerOptions, expectation): with expectation: ds = dg.Datasets(spark, providerName).get(**providerOptions) - assert ds is not None, f"""expected to get dataset specification for provider `{providerName}` + assert ( + ds is not None + ), f"""expected to get dataset specification for provider `{providerName}` with options: {providerOptions} """ df = ds.build() @@ -387,8 +723,15 @@ def test_multi_table_sales_order_retrieval(self, providerName, providerOptions, def test_full_multitable_sales_order_sequence(self): multiTableDataSet = dg.Datasets(spark, "multi_table/sales_order") - options = {"numCustomers": 100, "numOrders": 1000, "numCarriers": 10, "numCatalogItems": 100, - "startDate": "2024-01-01", "endDate": "2024-12-31", "lineItemsPerOrder": 3} + options = { + "numCustomers": 100, + "numOrders": 1000, + "numCarriers": 10, + "numCatalogItems": 100, + "startDate": "2024-01-01", + "endDate": "2024-12-31", + "lineItemsPerOrder": 3, + } dfCustomers = multiTableDataSet.get(table="customers", **options).build() dfCarriers = multiTableDataSet.get(table="carriers", **options).build() dfCatalogItems = multiTableDataSet.get(table="catalog_items", **options).build() @@ -407,7 +750,7 @@ def test_full_multitable_sales_order_sequence(self): baseOrders=dfBaseOrders, baseOrderLineItems=dfBaseOrderLineItems, 
baseOrderShipments=dfBaseOrderShipments, - baseInvoices=dfBaseInvoices + baseInvoices=dfBaseInvoices, ) assert df is not None @@ -415,25 +758,30 @@ def test_full_multitable_sales_order_sequence(self): assert df # MULTI-TABLE TELEPHONY tests: - @pytest.mark.parametrize("providerName, providerOptions", [ - ("multi_table/telephony", {"rows": 50, "partitions": 4, "random": False}), - ("multi_table/telephony", {"rows": -1, "partitions": 4, "random": False}), - ("multi_table/telephony", {}), - ("multi_table/telephony", {"rows": 100, "partitions": -1, "random": False}), - ("multi_table/telephony", {"rows": 5000, "dummyValues": 4}), - ("multi_table/telephony", {"rows": 100, "partitions": -1, "random": True}), - ("multi_table/telephony", {"table": 'plans', "numPlans": 100}), - ("multi_table/telephony", {"table": 'plans'}), - ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), - ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), - ("multi_table/telephony", {"table": 'customers'}), - ("multi_table/telephony", {"table": 'deviceEvents', "numPlans": 100, "numCustomers": 1000}), - ("multi_table/telephony", {"table": 'deviceEvents'}), - ("multi_table/telephony", {"table": 'deviceEvents', "numDays": 10}), - ]) + @pytest.mark.parametrize( + "providerName, providerOptions", + [ + ("multi_table/telephony", {"rows": 50, "partitions": 4, "random": False}), + ("multi_table/telephony", {"rows": -1, "partitions": 4, "random": False}), + ("multi_table/telephony", {}), + ("multi_table/telephony", {"rows": 100, "partitions": -1, "random": False}), + ("multi_table/telephony", {"rows": 5000, "dummyValues": 4}), + ("multi_table/telephony", {"rows": 100, "partitions": -1, "random": True}), + ("multi_table/telephony", {"table": 'plans', "numPlans": 100}), + ("multi_table/telephony", {"table": 'plans'}), + ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), + ("multi_table/telephony", {"table": 'customers', "numPlans": 100, "numCustomers": 1000}), + ("multi_table/telephony", {"table": 'customers'}), + ("multi_table/telephony", {"table": 'deviceEvents', "numPlans": 100, "numCustomers": 1000}), + ("multi_table/telephony", {"table": 'deviceEvents'}), + ("multi_table/telephony", {"table": 'deviceEvents', "numDays": 10}), + ], + ) def test_multi_table_retrieval(self, providerName, providerOptions): ds = dg.Datasets(spark, providerName).get(**providerOptions) - assert ds is not None, f"""expected to get dataset specification for provider `{providerName}` + assert ( + ds is not None + ), f"""expected to get dataset specification for provider `{providerName}` with options: {providerOptions} """ df = ds.build() @@ -446,9 +794,9 @@ def test_full_multitable_sequence(self): dfPlans = multiTableDS.get(table="plans", **options).build() dfCustomers = multiTableDS.get(table="customers", **options).build() dfDeviceEvents = multiTableDS.get(table="deviceEvents", **options).build() - dfInvoices = multiTableDS.getSummaryDataset(table="invoices", plans=dfPlans, - customers=dfCustomers, - deviceEvents=dfDeviceEvents) + dfInvoices = multiTableDS.getSummaryDataset( + table="invoices", plans=dfPlans, customers=dfCustomers, deviceEvents=dfDeviceEvents + ) assert dfInvoices is not None assert dfInvoices.count() >= 0 diff --git a/tests/test_standard_datasets.py b/tests/test_standard_datasets.py index 90afac04..69e73148 100644 --- a/tests/test_standard_datasets.py +++ b/tests/test_standard_datasets.py @@ -11,19 +11,25 @@ @pytest.fixture 
def mkTableSpec(): - dataspec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) - .withIdOutput() - .withColumn("code1", IntegerType(), min=100, max=200) - .withColumn("code2", IntegerType(), min=0, max=10) - ) + dataspec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000) + .withIdOutput() + .withColumn("code1", IntegerType(), min=100, max=200) + .withColumn("code2", IntegerType(), min=0, max=10) + ) return dataspec class TestStandardDatasetsFramework: # Define some dummy providers - we will use these to check if they are found by # the listing and describe methods etc. - @dataset_definition(name="test_providers/test_batch", summary="Test Data Set1", autoRegister=True, - tables=["green", "yellow", "red"], supportsStreaming=False) + @dataset_definition( + name="test_providers/test_batch", + summary="Test Data Set1", + autoRegister=True, + tables=["green", "yellow", "red"], + supportsStreaming=False, + ) class SampleDatasetProviderBatch(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): def __init__(self): pass @@ -48,40 +54,41 @@ def recordArgs(cls, *, table, options, rows, partitions): cls.lastPartitionsRequested = partitions @DatasetProvider.allowed_options(options=["random", "dummyValues"]) - def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, - **options): + def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options): generateRandom = options.get("random", True) dummyValues = options.get("dummyValues", 0) - ds = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='hash_fieldname') - .withColumn("code1", "int", min=100, max=200) - .withColumn("code2", "int", min=0, max=10) - .withColumn("code3", "string", values=['a', 'b', 'c']) - .withColumn("code4", "string", values=['a', 'b', 'c'], random=generateRandom) - .withColumn("code5", "string", values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - ) + ds = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='hash_fieldname') + .withColumn("code1", "int", min=100, max=200) + .withColumn("code2", "int", min=0, max=10) + .withColumn("code3", "string", values=['a', 'b', 'c']) + .withColumn("code4", "string", values=['a', 'b', 'c'], random=generateRandom) + .withColumn("code5", "string", values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) if dummyValues > 0: - ds = ds.withColumn("dummy", "long", random=True, numColumns=dummyValues, - minValue=1, maxValue=self.MAX_LONG) + ds = ds.withColumn( + "dummy", "long", random=True, numColumns=dummyValues, minValue=1, maxValue=self.MAX_LONG + ) return ds - @dataset_definition(name="test_providers/test_streaming", summary="Test Data Set2", autoRegister=True, - supportsStreaming=True) + @dataset_definition( + name="test_providers/test_streaming", summary="Test Data Set2", autoRegister=True, supportsStreaming=True + ) class SampleDatasetProviderStreaming(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): def __init__(self): pass - def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, - **options): - ds = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, - seedMethod='hash_fieldname') - .withColumn("code1", "int", min=100, max=200) - .withColumn("code2", "int", min=0, max=10) - .withColumn("code3", "string", values=['a', 'b', 'c']) - .withColumn("code4", "string", values=['a', 'b', 'c'], random=True) - .withColumn("code5", "string", values=['a', 
'b', 'c'], random=True, weights=[9, 1, 1]) - ) + def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options): + ds = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000, seedMethod='hash_fieldname') + .withColumn("code1", "int", min=100, max=200) + .withColumn("code2", "int", min=0, max=10) + .withColumn("code3", "string", values=['a', 'b', 'c']) + .withColumn("code4", "string", values=['a', 'b', 'c'], random=True) + .withColumn("code5", "string", values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) return ds @pytest.fixture @@ -94,7 +101,7 @@ def dataset_definition1(self): description="Description of the test dataset", supportsStreaming=True, providerClass=DatasetProvider, - associatedDatasets=None + associatedDatasets=None, ) def test_datasets_bad_table_name(self): @@ -124,6 +131,7 @@ def test_datasets_bad_associated_dataset_name(self): def test_datasets_bad_decorator_usage(self): with pytest.raises(TypeError): + @dataset_definition(name="test_providers/badly_applied_decorator", summary="Bad Usage", autoRegister=True) def by_two(x): return x * 2 @@ -179,6 +187,7 @@ def test_dataset_definition_attributes(self, dataset_definition1): def test_decorators1(self, mkTableSpec): import sys + print("sys.versioninfo", sys.version_info) @dataset_definition @@ -215,9 +224,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, parti def test_decorators1a(self, mkTableSpec): @dataset_definition(name="test/test", tables=["main1"]) class Y1a(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): - def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, partitions=4, - autoSizePartitions=False, - **options): + def getTableGenerator( + self, sparkSession, *, tableName=None, rows=1000000, partitions=4, autoSizePartitions=False, **options + ): return mkTableSpec ds_definition = Y1a.getDatasetDefinition() @@ -241,9 +250,16 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, parti def test_decorators1b(self, mkTableSpec): @dataset_definition(description="a test description") class X1b(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): - def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, - description="a test description", - **options): + def getTableGenerator( + self, + sparkSession, + *, + tableName=None, + rows=-1, + partitions=-1, + description="a test description", + **options, + ): return mkTableSpec ds_definition = X1b.getDatasetDefinition() @@ -274,9 +290,16 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions def test_bad_registration(self, mkTableSpec): @dataset_definition(description="a test description") class X1b(DatasetProvider): - def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, - description="a test description", - **options): + def getTableGenerator( + self, + sparkSession, + *, + tableName=None, + rows=-1, + partitions=-1, + description="a test description", + **options, + ): return mkTableSpec with pytest.raises(ValueError): @@ -291,23 +314,28 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions def test_invalid_decorator_use(self): with pytest.raises(TypeError): + @dataset_definition def my_function(x): return x def test_invalid_decorator_use2(self): with pytest.raises(TypeError): + @dataset_definition(name="test/bad_decorator1") def my_function(x): return x - @pytest.mark.parametrize("providerClass, options", - 
[(SampleDatasetProviderBatch, {}), - (SampleDatasetProviderBatch, {"pattern": "test.*"}), - (SampleDatasetProviderBatch, {"pattern": "test_providers/test_batch"}), - (SampleDatasetProviderBatch, {"supportsStreaming": False}), - (SampleDatasetProviderStreaming, {"supportsStreaming": True}) - ]) + @pytest.mark.parametrize( + "providerClass, options", + [ + (SampleDatasetProviderBatch, {}), + (SampleDatasetProviderBatch, {"pattern": "test.*"}), + (SampleDatasetProviderBatch, {"pattern": "test_providers/test_batch"}), + (SampleDatasetProviderBatch, {"supportsStreaming": False}), + (SampleDatasetProviderStreaming, {"supportsStreaming": True}), + ], + ) def test_listing(self, providerClass, options, capsys): print("listing datasets") @@ -334,9 +362,9 @@ def test_describe_basic_usr(self, capsys): @pytest.fixture def dataset_provider(self): class MyDatasetProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): - def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, partitions=4, - autoSizePartitions=False, - **options): + def getTableGenerator( + self, sparkSession, *, tableName=None, rows=1000000, partitions=4, autoSizePartitions=False, **options + ): return mkTableSpec return MyDatasetProvider() @@ -359,13 +387,10 @@ def test_check_options_invalid_options(self, dataset_provider): with pytest.raises(AssertionError): dataset_provider.checkOptions(options, allowed_options) - @pytest.mark.parametrize("rows, columns, expected_partitions", [ - (1000000, 10, 4), - (5000000, 100, 12), - (100, 2, 4), - (1000_000_000, 10, 18), - (5000_000_000, 30, 32) - ]) + @pytest.mark.parametrize( + "rows, columns, expected_partitions", + [(1000000, 10, 4), (5000000, 100, 12), (100, 2, 4), (1000_000_000, 10, 18), (5000_000_000, 30, 32)], + ) def test_auto_compute_partitions(self, dataset_provider, rows, columns, expected_partitions): partitions = dataset_provider.autoComputePartitions(rows, columns) assert partitions == expected_partitions diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f65361cb..a59c210b 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -37,26 +37,34 @@ def test_streaming(self, getStreamingDirs, seedColumnName): base_dir, test_dir, checkpoint_dir = getStreamingDirs if seedColumnName is not None: - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=4, seedMethod='hash_fieldname', seedColumnName=seedColumnName)) + testDataSpec = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=self.row_count, + partitions=4, + seedMethod='hash_fieldname', + seedColumnName=seedColumnName, + ) else: - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=4, seedMethod='hash_fieldname')) - - testDataSpec = (testDataSpec - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=self.column_count) - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) - - dfTestData = testDataSpec.build(withStreaming=True, - options={'rowsPerSecond': self.rows_per_second}) + testDataSpec = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=self.row_count, + 
partitions=4, + seedMethod='hash_fieldname', + ) + + testDataSpec = ( + testDataSpec.withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=self.column_count) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) + + dfTestData = testDataSpec.build(withStreaming=True, options={'rowsPerSecond': self.rows_per_second}) # check that seed column is in schema fields = [c.name for c in dfTestData.schema.fields] @@ -65,13 +73,13 @@ def test_streaming(self, getStreamingDirs, seedColumnName): assert seedColumnName in fields assert "id" not in fields if seedColumnName != "id" else True - sq = (dfTestData - .writeStream - .format("csv") - .outputMode("append") - .option("path", test_dir) - .option("checkpointLocation", f"{checkpoint_dir}/{uuid.uuid4()}") - .start()) + sq = ( + dfTestData.writeStream.format("csv") + .outputMode("append") + .option("path", test_dir) + .option("checkpointLocation", f"{checkpoint_dir}/{uuid.uuid4()}") + .start() + ) # loop until we get one seconds worth of data start_time = time.time() @@ -109,27 +117,34 @@ def test_streaming_trigger_once(self, getStreamingDirs, seedColumnName): base_dir, test_dir, checkpoint_dir = getStreamingDirs if seedColumnName is not None: - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=4, seedMethod='hash_fieldname', - seedColumnName=seedColumnName)) + testDataSpec = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=self.row_count, + partitions=4, + seedMethod='hash_fieldname', + seedColumnName=seedColumnName, + ) else: - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=4, seedMethod='hash_fieldname')) - - testDataSpec = (testDataSpec - .withIdOutput() - .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", - numColumns=self.column_count) - .withColumn("code1", IntegerType(), minValue=100, maxValue=200) - .withColumn("code2", IntegerType(), minValue=0, maxValue=10) - .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - - ) - - dfTestData = testDataSpec.build(withStreaming=True, - options={'rowsPerSecond': self.rows_per_second}) + testDataSpec = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=self.row_count, + partitions=4, + seedMethod='hash_fieldname', + ) + + testDataSpec = ( + testDataSpec.withIdOutput() + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=self.column_count) + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) + + dfTestData = testDataSpec.build(withStreaming=True, options={'rowsPerSecond': self.rows_per_second}) # check that seed column is in schema fields = [c.name for c in 
dfTestData.schema.fields] @@ -145,14 +160,14 @@ def test_streaming_trigger_once(self, getStreamingDirs, seedColumnName): time_limit = 10.0 while elapsed_time < time_limit and rows_retrieved < self.rows_per_second: - sq = (dfTestData - .writeStream - .format("csv") - .outputMode("append") - .option("path", test_dir) - .option("checkpointLocation", checkpoint_dir) - .trigger(once=True) - .start()) + sq = ( + dfTestData.writeStream.format("csv") + .outputMode("append") + .option("path", test_dir) + .option("checkpointLocation", checkpoint_dir) + .trigger(once=True) + .start() + ) # wait for trigger once to terminate sq.awaitTermination(5) diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 05a96772..8265ee84 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -10,23 +10,25 @@ import dbldatagen as dg from dbldatagen import TemplateGenerator, TextGenerator -schema = StructType([ - StructField("PK1", StringType(), True), - StructField("LAST_MODIFIED_UTC", TimestampType(), True), - StructField("date", DateType(), True), - StructField("str1", StringType(), True), - StructField("nint", IntegerType(), True), - StructField("nstr1", StringType(), True), - StructField("nstr2", StringType(), True), - StructField("nstr3", StringType(), True), - StructField("nstr4", StringType(), True), - StructField("nstr5", StringType(), True), - StructField("nstr6", StringType(), True), - StructField("email", StringType(), True), - StructField("ip_addr", StringType(), True), - StructField("phone", StringType(), True), - StructField("isDeleted", BooleanType(), True) -]) +schema = StructType( + [ + StructField("PK1", StringType(), True), + StructField("LAST_MODIFIED_UTC", TimestampType(), True), + StructField("date", DateType(), True), + StructField("str1", StringType(), True), + StructField("nint", IntegerType(), True), + StructField("nstr1", StringType(), True), + StructField("nstr2", StringType(), True), + StructField("nstr3", StringType(), True), + StructField("nstr4", StringType(), True), + StructField("nstr5", StringType(), True), + StructField("nstr6", StringType(), True), + StructField("email", StringType(), True), + StructField("ip_addr", StringType(), True), + StructField("phone", StringType(), True), + StructField("isDeleted", BooleanType(), True), + ] +) # add the following if using pandas udfs # .config("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") \ @@ -62,22 +64,24 @@ def test_text_generator_basics(self): random_values1 = rng1.integers(10, 20, dtype=np.int32) assert 10 <= random_values1 <= 20 - @pytest.mark.parametrize("template, escapeSpecial, low, high, useSystemLib", - [ - (r'\n.\n.\n.\n', False, 0, 15, False), - (r'\n.\n.\n.\n', False, 20, 35, False), - (r'\n.\n.\n.\n', False, 15, None, False), - (r'\n.\n.\n.\n', False, 15, -1, False), - (r'\n.\n.\n.\n', False, 0, 15, True), - (r'\n.\n.\n.\n', False, 20, 35, True), - (r'\n.\n.\n.\n', False, 15, None, True), - (r'\n.\n.\n.\n', False, 15, -1, True), - ]) - def test_random_number_generator(self, template, escapeSpecial, low, high, useSystemLib): \ - # pylint: disable=too-many-positional-arguments - - """ As the test coverage tools dont detect code only used in UDFs, - lets add some explicit tests for the underlying code""" + @pytest.mark.parametrize( + "template, escapeSpecial, low, high, useSystemLib", + [ + (r'\n.\n.\n.\n', False, 0, 15, False), + (r'\n.\n.\n.\n', False, 20, 35, False), + (r'\n.\n.\n.\n', False, 15, None, False), + (r'\n.\n.\n.\n', False, 15, -1, False), + 
(r'\n.\n.\n.\n', False, 0, 15, True), + (r'\n.\n.\n.\n', False, 20, 35, True), + (r'\n.\n.\n.\n', False, 15, None, True), + (r'\n.\n.\n.\n', False, 15, -1, True), + ], + ) + def test_random_number_generator( + self, template, escapeSpecial, low, high, useSystemLib + ): # pylint: disable=too-many-positional-arguments + """As the test coverage tools dont detect code only used in UDFs, + lets add some explicit tests for the underlying code""" test_template = TemplateGenerator(template, escapeSpecialChars=escapeSpecial) rng1 = test_template.getNPRandomGenerator() @@ -100,13 +104,16 @@ def test_random_number_generator(self, template, escapeSpecial, low, high, useSy assert low <= random_value <= high - @pytest.mark.parametrize("template, escapeSpecial, expectedTemplates", - [(r'\n.\n.\n.\n', True, 1), - (r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd', False, 3), - (r'(\d\d\d)-\d\d\d-\d\d\d\d|1(\d\d\d) \d\d\d-\d\d\d\d|\d\d\d \d\d\d\d\d\d\d', True, 3), - (r'\dr_\v', False, 1), - (r'\w.\w@\w.com|\w@\w.co.u\k', False, 2), - ]) + @pytest.mark.parametrize( + "template, escapeSpecial, expectedTemplates", + [ + (r'\n.\n.\n.\n', True, 1), + (r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd', False, 3), + (r'(\d\d\d)-\d\d\d-\d\d\d\d|1(\d\d\d) \d\d\d-\d\d\d\d|\d\d\d \d\d\d\d\d\d\d', True, 3), + (r'\dr_\v', False, 1), + (r'\w.\w@\w.com|\w@\w.co.u\k', False, 2), + ], + ) def test_template_generator_properties(self, template, escapeSpecial, expectedTemplates): test_template = TemplateGenerator(template, escapeSpecialChars=escapeSpecial) @@ -120,24 +127,24 @@ def test_template_generator_properties(self, template, escapeSpecial, expectedTe assert len(test_template.templates) == expectedTemplates def test_simple_data_template(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percentNulls=0.1) - .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - format="%04f") - .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) - .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - .withColumnSpec("nstr5", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) - .withColumnSpec("nstr6", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, - format="%04f") - .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') - .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') - .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') - ) + testDataSpec = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set1", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percentNulls=0.1) + .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, format="%04f") + .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) + .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + .withColumnSpec("nstr5", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumnSpec("nstr6", 
percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, format="%04f") + .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') + .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') + .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') + ) df_template_data = testDataSpec.build().cache() @@ -146,7 +153,7 @@ def test_simple_data_template(self): counts = df_template_data.agg( F.countDistinct("email").alias("email_count"), F.countDistinct("ip_addr").alias("ip_addr_count"), - F.countDistinct("phone").alias("phone_count") + F.countDistinct("phone").alias("phone_count"), ).collect()[0] assert counts['email_count'] >= 10 @@ -164,31 +171,31 @@ def test_simple_data_template(self): assert phone_patt.match(r["phone"]), "check phone" def test_large_template_driven_data_generation(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000000, - partitions=24) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percent_nulls=0.1) - .withColumnSpec("nint", percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr1", percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr2", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - format="%04f") - .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) - .withColumnSpec("nstr4", percent_nulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - .withColumnSpec("nstr5", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) - .withColumnSpec("nstr6", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, - format="%04f") - .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') - .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') - .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000000, partitions=24) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percent_nulls=0.1) + .withColumnSpec("nint", percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr1", percent_nulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr2", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, format="%04f") + .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) + .withColumnSpec("nstr4", percent_nulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + .withColumnSpec("nstr5", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumnSpec( + "nstr6", percent_nulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, format="%04f" + ) + .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') + .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') + .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') + ) df_template_data = testDataSpec.build() counts = df_template_data.agg( F.countDistinct("email").alias("email_count"), F.countDistinct("ip_addr").alias("ip_addr_count"), - F.countDistinct("phone").alias("phone_count") + F.countDistinct("phone").alias("phone_count"), ).collect()[0] assert counts['email_count'] >= 100 @@ -196,8 +203,8 @@ def test_large_template_driven_data_generation(self): assert counts['phone_count'] >= 100 def test_raw_iltext_text_generation(self): - """ As the test coverage tools dont detect code only used in UDFs, - lets add some explicit tests for the underlying code""" + """As the test coverage tools dont detect code only used in 
UDFs, + let's add some explicit tests for the underlying code""" # test the IL Text generator tg1 = dg.ILText(paragraphs=(1, 4), sentences=(2, 6), words=(1, 8)) @@ -219,29 +226,24 @@ def test_raw_iltext_text_generation(self): assert match_pattern.match(test_value) def test_large_ILText_driven_data_generation(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000000, - partitions=8) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percentNulls=0.1) - .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - format="%04f") - .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) - .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - .withColumnSpec("nstr5", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) - .withColumnSpec("nstr6", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, - format="%04f") - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6), words=(1, 8))) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=1000000, partitions=8) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percentNulls=0.1) + .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, format="%04f") + .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) + .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + .withColumnSpec("nstr5", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumnSpec("nstr6", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, format="%04f") + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6), words=(1, 8))) + ) df_template_data = testDataSpec.build() - counts = df_template_data.agg( - F.countDistinct("paras").alias("paragraphs_count") - ).collect()[0] + counts = df_template_data.agg(F.countDistinct("paras").alias("paragraphs_count")).collect()[0] assert counts['paragraphs_count'] >= 100 @@ -250,7 +252,7 @@ def test_large_ILText_driven_data_generation(self): counts = df_template_data.agg( F.countDistinct("email").alias("email_count"), F.countDistinct("ip_addr").alias("ip_addr_count"), - F.countDistinct("phone").alias("phone_count") + F.countDistinct("phone").alias("phone_count"), ).collect()[0] assert counts['email_count'] >= 100 @@ -258,29 +260,24 @@ def test_large_ILText_driven_data_generation(self): assert counts['phone_count'] >= 100 def test_small_ILText_driven_data_generation(self): - testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, - partitions=8) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percentNulls=0.1) - .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, - format="%04f") - .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) - .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") - .withColumnSpec("nstr5", 
percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) - .withColumnSpec("nstr6", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, - format="%04f") - .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6), words=(1, 8))) - - ) + testDataSpec = ( + dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=100000, partitions=8) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percentNulls=0.1) + .withColumnSpec("nint", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr1", percentNulls=0.1, minValue=1, maxValue=9, step=2) + .withColumnSpec("nstr2", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, format="%04f") + .withColumnSpec("nstr3", minValue=1.0, maxValue=9.0, step=2.0) + .withColumnSpec("nstr4", percentNulls=0.1, minValue=1, maxValue=9, step=2, format="%04d") + .withColumnSpec("nstr5", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True) + .withColumnSpec("nstr6", percentNulls=0.1, minValue=1.5, maxValue=2.5, step=0.3, random=True, format="%04f") + .withColumn("paras", text=dg.ILText(paragraphs=(1, 4), sentences=(2, 6), words=(1, 8))) + ) df_iltext_data = testDataSpec.build() - counts = df_iltext_data.agg( - F.countDistinct("paras").alias("paragraphs_count") - ).collect()[0] + counts = df_iltext_data.agg(F.countDistinct("paras").alias("paragraphs_count")).collect()[0] assert counts['paragraphs_count'] >= 10 @@ -293,18 +290,27 @@ def test_small_ILText_driven_data_generation(self): assert test_value is not None assert match_pattern.match(test_value) - @pytest.mark.parametrize("template, expectedOutput, escapeSpecial", - [(r'\n.\n.\n.\n', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", False), - (r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd', - r"(\([0-9]+\)-[0-9]+-[0-9]+)|(1\([0-9]+\) [0-9]+-[0-9]+)|([0-9]+ [0-9]+)", False), - (r'(\d\d\d)-\d\d\d-\d\d\d\d|1(\d\d\d) \d\d\d-\d\d\d\d|\d\d\d \d\d\d\d\d\d\d', - r"(\([0-9]+\)-[0-9]+-[0-9]+)|(1\([0-9]+\) [0-9]+-[0-9]+)|([0-9]+ [0-9]+)", True), - (r'\dr_\v', r"dr_[0-9]+", False), - (r'\w.\w@\w.com|\w@\w.co.u\k', r"[a-z\.\@]+", False), - ]) + @pytest.mark.parametrize( + "template, expectedOutput, escapeSpecial", + [ + (r'\n.\n.\n.\n', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", False), + ( + r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd', + r"(\([0-9]+\)-[0-9]+-[0-9]+)|(1\([0-9]+\) [0-9]+-[0-9]+)|([0-9]+ [0-9]+)", + False, + ), + ( + r'(\d\d\d)-\d\d\d-\d\d\d\d|1(\d\d\d) \d\d\d-\d\d\d\d|\d\d\d \d\d\d\d\d\d\d', + r"(\([0-9]+\)-[0-9]+-[0-9]+)|(1\([0-9]+\) [0-9]+-[0-9]+)|([0-9]+ [0-9]+)", + True, + ), + (r'\dr_\v', r"dr_[0-9]+", False), + (r'\w.\w@\w.com|\w@\w.co.u\k', r"[a-z\.\@]+", False), + ], + ) def test_raw_template_text_generation1(self, template, expectedOutput, escapeSpecial): - """ As the test coverage tools dont detect code only used in UDFs, - lets add some explicit tests for the underlying code""" + """As the test coverage tools don't detect code only used in UDFs, + let's add some explicit tests for the underlying code""" match_pattern = re.compile(expectedOutput) test_template = TemplateGenerator(template, escapeSpecialChars=escapeSpecial) @@ -319,8 +325,8 @@ def check_pattern(x): assert match_pattern.match(test_value), f"pattern '{expectedOutput}' doesn't match result '{test_value}'" def test_raw_template_text_generation3(self): - """ As the test coverage tools don't detect code only used in UDFs, - lets add some explicit tests for the underlying code""" + """As the test coverage tools don't detect code only used in UDFs, + let's add some 
explicit tests for the underlying code""" pattern = r'\w.\w@\w.com|\w@\w.co.u\k' match_pattern = re.compile(r"[a-z\.\@]+") test_template = TemplateGenerator(pattern) @@ -332,19 +338,23 @@ def test_raw_template_text_generation3(self): assert match_pattern.match(test_value) def test_simple_data2(self): - testDataSpec2 = (dg.DataGenerator(sparkSession=spark, name="test_data_set2", rows=self.row_count, - partitions=self.partitions_requested) - .withSchema(schema) - .withIdOutput() - .withColumnSpec("date", percent_nulls=0.1) - .withColumnSpecs(patterns="n.*", match_types=StringType(), - percent_nulls=0.1, minValue=1, maxValue=9, step=2) - .withColumnSpecs(patterns="n.*", match_types=IntegerType(), - percent_nulls=0.1, minValue=1, maxValue=200, step=-2) - .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') - .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') - .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') - ) + testDataSpec2 = ( + dg.DataGenerator( + sparkSession=spark, name="test_data_set2", rows=self.row_count, partitions=self.partitions_requested + ) + .withSchema(schema) + .withIdOutput() + .withColumnSpec("date", percent_nulls=0.1) + .withColumnSpecs( + patterns="n.*", match_types=StringType(), percent_nulls=0.1, minValue=1, maxValue=9, step=2 + ) + .withColumnSpecs( + patterns="n.*", match_types=IntegerType(), percent_nulls=0.1, minValue=1, maxValue=200, step=-2 + ) + .withColumnSpec("email", template=r'\w.\w@\w.com|\w@\w.co.u\k') + .withColumnSpec("ip_addr", template=r'\n.\n.\n.\n') + .withColumnSpec("phone", template=r'(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd') + ) testDataSpec2.build().show() df_template_data = testDataSpec2.build() @@ -352,7 +362,7 @@ def test_simple_data2(self): counts = df_template_data.agg( F.countDistinct("email").alias("email_count"), F.countDistinct("ip_addr").alias("ip_addr_count"), - F.countDistinct("phone").alias("phone_count") + F.countDistinct("phone").alias("phone_count"), ).collect()[0] assert counts['email_count'] >= 100 @@ -360,86 +370,124 @@ def test_simple_data2(self): assert counts['phone_count'] >= 100 def test_multi_columns(self): - testDataSpec3 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", - template=r"\v-1") - ) + testDataSpec3 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", template=r"\v-1") + ) testDataSpec3.build().show() def test_multi_columns2(self): - testDataSpec4 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", - template=r"\v0-\v1") - ) + testDataSpec4 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + 
partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", template=r"\v0-\v1") + ) # in this case we expect values of the form ` - ` testDataSpec4.build().show() def test_multi_columns3(self): - testDataSpec5 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", - template=r"\v\0-\v\1") - ) + testDataSpec5 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn( + "val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", template=r"\v\0-\v\1" + ) + ) # in this case we expect values of the form `[ array of values]0 - [array of values]1` testDataSpec5.build().show() def test_multi_columns4(self): - testDataSpec6 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="hash", - template=r"\v0-\v1") - ) + testDataSpec6 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="hash", template=r"\v0-\v1") + ) # here the values for val3 are undefined as the base value for the column is a hash of the base columns testDataSpec6.build().show() def test_multi_columns5(self): - testDataSpec7 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="hash", - format="%s") - ) + testDataSpec7 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="hash", format="%s") + ) # here the values for val3 are undefined as the base value for the column is a hash of the base columns testDataSpec7.build().show() def test_multi_columns6(self): - testDataSpec8 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - 
.withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", - format="%s") - ) + testDataSpec8 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], baseColumnType="values", format="%s") + ) # here the values for val3 are undefined as the base value for the column is a hash of the base columns testDataSpec8.build().show() def test_multi_columns7(self): - testDataSpec9 = (dg.DataGenerator(sparkSession=spark, name="test_data_set3", rows=self.row_count, - partitions=self.partitions_requested, verbose=True) - .withIdOutput() - .withColumn("val1", IntegerType(), percentNulls=0.1) - .withColumn("val2", IntegerType(), percentNulls=0.1) - .withColumn("val3", StringType(), baseColumn=["val1", "val2"], format="%s") - ) + testDataSpec9 = ( + dg.DataGenerator( + sparkSession=spark, + name="test_data_set3", + rows=self.row_count, + partitions=self.partitions_requested, + verbose=True, + ) + .withIdOutput() + .withColumn("val1", IntegerType(), percentNulls=0.1) + .withColumn("val2", IntegerType(), percentNulls=0.1) + .withColumn("val3", StringType(), baseColumn=["val1", "val2"], format="%s") + ) # here the values for val3 are undefined as the base value for the column is a hash of the base columns testDataSpec9.build().show() @@ -453,14 +501,12 @@ def test_prefix(self): .withColumn("code3", "integer", minValue=1, maxValue=20, step=1) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1) # base column specifies dependent column - .withColumn("site_cd", "string", prefix='site', baseColumn='code1') .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) - .withColumn("site_cd2", "string", prefix='site', baseColumn='code1', text_separator=":") - .withColumn("device_status2", "string", minValue=1, maxValue=200, step=1, - prefix='status', text_separator=":") - + .withColumn( + "device_status2", "string", minValue=1, maxValue=200, step=1, prefix='status', text_separator=":" + ) ) df = testdata_generator.build() # build our dataset @@ -512,13 +558,12 @@ def test_suffix(self): .withColumn("code3", "integer", minValue=1, maxValue=20, step=1) .withColumn("code4", "integer", minValue=1, maxValue=20, step=1) # base column specifies dependent column - .withColumn("site_cd", "string", suffix='site', baseColumn='code1') .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, suffix='status', random=True) - .withColumn("site_cd2", "string", suffix='site', baseColumn='code1', text_separator=":") - .withColumn("device_status2", "string", minValue=1, maxValue=200, step=1, - suffix='status', text_separator=":") + .withColumn( + "device_status2", "string", minValue=1, maxValue=200, step=1, suffix='status', text_separator=":" + ) ) df = testdata_generator.build() # build our dataset @@ -555,11 +600,17 @@ def test_prefix_and_suffix(self): # base column specifies dependent column .withColumn("site_cd", "string", suffix='site', baseColumn='code1', prefix="test") .withColumn("device_status", "string", minValue=1, maxValue=200, step=1, suffix='status', prefix="test") - .withColumn("site_cd2", "string", suffix='site', baseColumn='code1', text_separator=":", prefix="test") - 
.withColumn("device_status2", "string", minValue=1, maxValue=200, step=1, - suffix='status', text_separator=":", - prefix="test") + .withColumn( + "device_status2", + "string", + minValue=1, + maxValue=200, + step=1, + suffix='status', + text_separator=":", + prefix="test", + ) ) df = testdata_generator.build() # build our dataset diff --git a/tests/test_text_generator_basic.py b/tests/test_text_generator_basic.py index d3a5d8f4..daaf59af 100644 --- a/tests/test_text_generator_basic.py +++ b/tests/test_text_generator_basic.py @@ -14,7 +14,7 @@ class TestTextGeneratorBasic: partitions_requested = 4 class TestTextGenerator(TextGenerator): - def pandasGenerateText(self, v): # pylint: disable=useless-parent-delegation + def pandasGenerateText(self, v): # pylint: disable=useless-parent-delegation return super().pandasGenerateText(v) @pytest.mark.parametrize("randomSeed", [None, 0, -1, 2112, 42]) @@ -39,11 +39,21 @@ def test_base_textgenerator_raises_error(self): text_gen1 = self.TestTextGenerator() text_gen1.pandasGenerateText(None) - @pytest.mark.parametrize("randomSeed, forceNewInstance", [(None, True), (None, False), - (0, True), (0, False), - (-1, True), (-1, False), - (2112, True), (2112, False), - (42, True), (42, False)]) + @pytest.mark.parametrize( + "randomSeed, forceNewInstance", + [ + (None, True), + (None, False), + (0, True), + (0, False), + (-1, True), + (-1, False), + (2112, True), + (2112, False), + (42, True), + (42, False), + ], + ) def test_text_generator_rng(self, randomSeed, forceNewInstance): text_gen1 = self.TestTextGenerator() text_gen2 = self.TestTextGenerator() @@ -74,21 +84,29 @@ def test_text_generator_rng(self, randomSeed, forceNewInstance): if randomSeed is not None and randomSeed != -1 and forceNewInstance: assert (values1 == values2).all() - @pytest.mark.parametrize("values, expectedType", [([1, 2, 3], np.uint8), - ([1, 40000, 3], np.uint16), - ([1, 40000.0, 3], np.uint16), - ([1, 40000.4, 3], np.uint16), - (np.array([1, 40000.4, 3]), np.uint16) - ]) + @pytest.mark.parametrize( + "values, expectedType", + [ + ([1, 2, 3], np.uint8), + ([1, 40000, 3], np.uint16), + ([1, 40000.0, 3], np.uint16), + ([1, 40000.4, 3], np.uint16), + (np.array([1, 40000.4, 3]), np.uint16), + ], + ) def test_text_generator_compact_types(self, values, expectedType): text_gen1 = self.TestTextGenerator() np_type = text_gen1.compactNumpyTypeForValues(values) assert np_type == expectedType - @pytest.mark.parametrize("template, expectedOutput", [(r'53.123.ddd.ddd', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"), - (r'\w.\W.\w.\W', r"[a-z]+\.[A-Z]+\.[a-z]+\.[A-Z]+"), - ]) + @pytest.mark.parametrize( + "template, expectedOutput", + [ + (r'53.123.ddd.ddd', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"), + (r'\w.\W.\w.\W', r"[a-z]+\.[A-Z]+\.[a-z]+\.[A-Z]+"), + ], + ) def test_templates_constant(self, template, expectedOutput): text_gen1 = TemplateGenerator(template) text_gen1 = text_gen1.withRandomSeed(2112) @@ -102,17 +120,16 @@ def test_templates_constant(self, template, expectedOutput): for x in results[0:100]: assert patt.match(x), f"Expecting data '{x}' to match pattern {patt}" - @pytest.mark.parametrize("sourceData, template, expectedOutput", [(np.arange(100000), - r'53.123.\V.\V', r"53\.123\.105\.105"), - (np.arange(100000), - r'\V.\V.\V.123', r"105\.105\.105\.123"), - (np.arange(1000), - r'\V.\W.\w.\W', r"105\.[A-Z]+\.[a-z]+\.[A-Z]+"), - ([[x, x + 1] for x in np.arange(10000)], - r'\v0.\v1.\w.\W', r"105\.106\.[a-z]+\.[A-Z]+"), - ([(x, x + 1) for x in range(10000)], - r'\v0.\v1.\w.\W', r"105\.106\.[a-z]+\.[A-Z]+"), 
- ]) + @pytest.mark.parametrize( + "sourceData, template, expectedOutput", + [ + (np.arange(100000), r'53.123.\V.\V', r"53\.123\.105\.105"), + (np.arange(100000), r'\V.\V.\V.123', r"105\.105\.105\.123"), + (np.arange(1000), r'\V.\W.\w.\W', r"105\.[A-Z]+\.[a-z]+\.[A-Z]+"), + ([[x, x + 1] for x in np.arange(10000)], r'\v0.\v1.\w.\W', r"105\.106\.[a-z]+\.[A-Z]+"), + ([(x, x + 1) for x in range(10000)], r'\v0.\v1.\w.\W', r"105\.106\.[a-z]+\.[A-Z]+"), + ], + ) def test_template_value_substitution(self, sourceData, template, expectedOutput): """ Test value substition for row 105 @@ -133,10 +150,14 @@ def test_template_value_substitution(self, sourceData, template, expectedOutput) test_row = results[105] assert patt.match(test_row), f"Expecting data '{test_row}' to match pattern {patt}" - @pytest.mark.parametrize("template, expectedRandomNumbers", [(r'53.123.ddd.ddd', 6), - (r'\w.\W.\w.\W', 4), - (r'\w.\W.\w.\W|\w \w|\W \w \W', [4, 2, 3]), - ]) + @pytest.mark.parametrize( + "template, expectedRandomNumbers", + [ + (r'53.123.ddd.ddd', 6), + (r'\w.\W.\w.\W', 4), + (r'\w.\W.\w.\W|\w \w|\W \w \W', [4, 2, 3]), + ], + ) def test_prepare_templates(self, template, expectedRandomNumbers): text_gen1 = TemplateGenerator(template) text_gen1 = text_gen1.withRandomSeed(2112) @@ -154,10 +175,14 @@ def test_prepare_templates(self, template, expectedRandomNumbers): assert len(vector_rnd) == expectedVectorSize, f"template is '{individualTemplate}'" assert placeholders > len(vector_rnd) - @pytest.mark.parametrize("template, expectedRandomNumbers", [(r'53.123.ddd.ddd', 6), - (r'\w.\W.\w.\W', 4), - (r'\w.\W.\w.\W|\w \w|\W \w \W', [4, 2, 3]), - ]) + @pytest.mark.parametrize( + "template, expectedRandomNumbers", + [ + (r'53.123.ddd.ddd', 6), + (r'\w.\W.\w.\W', 4), + (r'\w.\W.\w.\W|\w \w|\W \w \W', [4, 2, 3]), + ], + ) def test_prepare_bounds(self, template, expectedRandomNumbers): text_gen1 = TemplateGenerator(template) text_gen1 = text_gen1.withRandomSeed(2112) @@ -183,17 +208,20 @@ def test_prepare_bounds(self, template, expectedRandomNumbers): print(template_rnds) - @pytest.mark.parametrize("template, expectedOutput", [(r'53.123.ddd.ddd', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"), - (r'\w.\W.\w.\W', r"[a-z]+\.[A-Z]+\.[a-z]+\.[A-Z]+"), - (r'\w.\W.\w.\W|\w \w', r"[a-zA-Z\. ]+"), - (r'Dddd', r"[1-9][0-9]+"), - (r'\xxxxx \xXXX', r"x[0-9a-f]+ x[0-9A-F]+"), - (r'Aaaa Kkkk \N \n', r"[a-zA-Z]+ [0-9A-Za-z]+ [0-9]+ [0-9]+"), - (r'Aaaa \V \N \n', r"[A-Za-z]+ [0-9]+ [0-9]+ [0-9]+"), - ]) + @pytest.mark.parametrize( + "template, expectedOutput", + [ + (r'53.123.ddd.ddd', r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+"), + (r'\w.\W.\w.\W', r"[a-z]+\.[A-Z]+\.[a-z]+\.[A-Z]+"), + (r'\w.\W.\w.\W|\w \w', r"[a-zA-Z\. 
]+"), + (r'Dddd', r"[1-9][0-9]+"), + (r'\xxxxx \xXXX', r"x[0-9a-f]+ x[0-9A-F]+"), + (r'Aaaa Kkkk \N \n', r"[a-zA-Z]+ [0-9A-Za-z]+ [0-9]+ [0-9]+"), + (r'Aaaa \V \N \n', r"[A-Za-z]+ [0-9]+ [0-9]+ [0-9]+"), + ], + ) def test_apply_template(self, template, expectedOutput): - """ This method tests the core logic of the template generator - """ + """This method tests the core logic of the template generator""" text_gen1 = TemplateGenerator(template) text_gen1 = text_gen1.withRandomSeed(2112) @@ -230,11 +258,7 @@ def test_apply_template(self, template, expectedOutput): np.ma.harden_mask(m) # expand values into placeholders - text_gen1._applyTemplateStringsForTemplate(pd_data, - selectedTemplate, - masked_placeholders, - masked_rnds - ) + text_gen1._applyTemplateStringsForTemplate(pd_data, selectedTemplate, masked_placeholders, masked_rnds) # soften mask, allowing modifications for m in masked_matrices: diff --git a/tests/test_text_templates.py b/tests/test_text_templates.py index f4edd1d2..0c554d14 100644 --- a/tests/test_text_templates.py +++ b/tests/test_text_templates.py @@ -26,23 +26,25 @@ class TestTextTemplates: row_count = 100000 partitions_requested = 4 - @pytest.mark.parametrize("templates, splitTemplates", - [ - (r"a|b", ['a', 'b']), - (r"a|b|", ['a', 'b', '']), - (r"a", ['a']), - (r"", ['']), - (r"a\|b", [r'a|b']), - (r"a\\|b", [r'a\\', 'b']), - (r"a\|b|c", [r'a|b', 'c']), - (r"123,$456|test test2 |\|\a\\a |021 \| 123", - ['123,$456', 'test test2 ', '|\\a\\\\a ', '021 | 123']), - ( - r"123 \\| 123 \|123 | 123|123|123 |asd023,\|23|", - ['123 \\\\', ' 123 |123 ', ' 123', '123', '123 ', 'asd023,|23', '']), - (r" 123|123|123 |asd023,\|23", [' 123', '123', '123 ', 'asd023,|23']), - (r'', ['']) - ]) + @pytest.mark.parametrize( + "templates, splitTemplates", + [ + (r"a|b", ['a', 'b']), + (r"a|b|", ['a', 'b', '']), + (r"a", ['a']), + (r"", ['']), + (r"a\|b", [r'a|b']), + (r"a\\|b", [r'a\\', 'b']), + (r"a\|b|c", [r'a|b', 'c']), + (r"123,$456|test test2 |\|\a\\a |021 \| 123", ['123,$456', 'test test2 ', '|\\a\\\\a ', '021 | 123']), + ( + r"123 \\| 123 \|123 | 123|123|123 |asd023,\|23|", + ['123 \\\\', ' 123 |123 ', ' 123', '123', '123 ', 'asd023,|23', ''], + ), + (r" 123|123|123 |asd023,\|23", [' 123', '123', '123 ', 'asd023,|23']), + (r'', ['']), + ], + ) def test_split_templates(self, templates, splitTemplates): tg1 = TemplateGenerator("test", escapeSpecialChars=False) @@ -50,66 +52,68 @@ def test_split_templates(self, templates, splitTemplates): assert results == splitTemplates - @pytest.mark.parametrize("templateProvided, escapeSpecial, useTemplateObject", - [ # (r'\w \w|\w \v. \w', False, False), - (r'A', False, True), - (r'D', False, True), - (r'K', False, True), - (r'X', False, True), - (r'\W', False, True), - (r'\W', True, True), - (r'\\w A. \\w|\\w \\w', False, False), - (r'\\w \\w|\\w A. \\w', False, False), - (r'\\w \\w|\\w A. \\w', False, True), - (r'\\w \\w|\\w A. \\w', True, True), - (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), - (r'\\w \\w|\\w K. \\w', False, False), - (r'\\w \\w|\\w K. \\w', False, True), - (r'\\w \\w|\\w K. \\w', True, True), - (r'\\w \\w|\\w X. \\w', False, False), - (r'\\w \\w|\\w X. \\w', False, True), - (r'\\w \\w|\\w X. \\w', True, True), - (r'\\w \\w|\\w a. \\w', False, False), - (r'\\w \\w|\\w a. \\w', False, True), - (r'\\w \\w|\\w a. \\w', True, True), - (r'\\w \\w|\\w k. \\w', False, False), - (r'\\w \\w|\\w k. \\w', False, True), - (r'\\w \\w|\\w k. \\w', True, True), - (r'\\w \\w|\\w x. \\w', False, False), - (r'\\w \\w|\\w x. 
\\w', False, True), - (r'\\w \\w|\\w x. \\w', True, True), - (r'\\w a. \\w', False, True), - (r'\\w a. \\w|\\w \\w', False, False), - (r'\\w k. \\w', False, True), - (r'\\w k. \\w|\\w \\w', False, False), - (r'\n', False, True), - (r'\n', True, True), - (r'\v', False, True), - (r'\v', True, True), - (r'\w A. \w', False, False), - (r'\w \a. \w', True, True), - (r'\w \k. \w', True, True), - (r'\w \n \w', True, True), - (r'\w \w|\w A. \w', False, False), - (r'\w \w|\w \A. \w', True, True), - (r'\w \w|\w \a. \w', True, True), - (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), - (r'\w aAdDkK \w', False, False), - (r'\w aAdDkKxX \n \N \w', False, False), - (r'\w', False, False), - (r'\w', False, True), - (r'\w', True, True), - (r'a', False, True), - (r'b', False, False), - (r'b', False, True), - (r'b', True, True), - (r'd', False, True), - (r'k', False, True), - (r'x', False, True), - ('', False, False), - ('', False, True), - (r'', True, True), - ]) + @pytest.mark.parametrize( + "templateProvided, escapeSpecial, useTemplateObject", + [ # (r'\w \w|\w \v. \w', False, False), + (r'A', False, True), + (r'D', False, True), + (r'K', False, True), + (r'X', False, True), + (r'\W', False, True), + (r'\W', True, True), + (r'\\w A. \\w|\\w \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, True), + (r'\\w \\w|\\w A. \\w', True, True), + (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), + (r'\\w \\w|\\w K. \\w', False, False), + (r'\\w \\w|\\w K. \\w', False, True), + (r'\\w \\w|\\w K. \\w', True, True), + (r'\\w \\w|\\w X. \\w', False, False), + (r'\\w \\w|\\w X. \\w', False, True), + (r'\\w \\w|\\w X. \\w', True, True), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\\w \\w|\\w k. \\w', False, False), + (r'\\w \\w|\\w k. \\w', False, True), + (r'\\w \\w|\\w k. \\w', True, True), + (r'\\w \\w|\\w x. \\w', False, False), + (r'\\w \\w|\\w x. \\w', False, True), + (r'\\w \\w|\\w x. \\w', True, True), + (r'\\w a. \\w', False, True), + (r'\\w a. \\w|\\w \\w', False, False), + (r'\\w k. \\w', False, True), + (r'\\w k. \\w|\\w \\w', False, False), + (r'\n', False, True), + (r'\n', True, True), + (r'\v', False, True), + (r'\v', True, True), + (r'\w A. \w', False, False), + (r'\w \a. \w', True, True), + (r'\w \k. \w', True, True), + (r'\w \n \w', True, True), + (r'\w \w|\w A. \w', False, False), + (r'\w \w|\w \A. \w', True, True), + (r'\w \w|\w \a. \w', True, True), + (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), + (r'\w aAdDkK \w', False, False), + (r'\w aAdDkKxX \n \N \w', False, False), + (r'\w', False, False), + (r'\w', False, True), + (r'\w', True, True), + (r'a', False, True), + (r'b', False, False), + (r'b', False, True), + (r'b', True, True), + (r'd', False, True), + (r'k', False, True), + (r'x', False, True), + ('', False, False), + ('', False, True), + (r'', True, True), + ], + ) def test_rnd_compute(self, templateProvided, escapeSpecial, useTemplateObject): template1 = TemplateGenerator(templateProvided, escapeSpecialChars=escapeSpecial) print(f"template [{templateProvided}]") @@ -133,81 +137,83 @@ def test_rnd_compute(self, templateProvided, escapeSpecial, useTemplateObject): for iy, bounds_value in enumerate(bounds): assert bounds_value == -1 or (rnds[iy] < bounds_value) - @pytest.mark.parametrize("templateProvided, escapeSpecial, useTemplateObject", - [ # (r'\w \w|\w \v. \w', False, False), - (r'\\w \\w|\\w a. \\w', False, False), - (r'\\w \\w|\\w a. 
\\w', False, True), - (r'\\w \\w|\\w a. \\w', True, True), - (r'\w \w|\w a. \w', False, False), - (r'\w.\w@\w.com', False, False), - (r'\n-\n', False, False), - (r'A', False, True), - (r'D', False, True), - (r'K', False, True), - (r'X', False, True), - (r'\W', False, True), - (r'\W', True, True), - (r'\\w A. \\w|\\w \\w', False, False), - (r'\\w \\w|\\w A. \\w', False, False), - (r'\\w \\w|\\w A. \\w', False, True), - (r'\\w \\w|\\w A. \\w', True, True), - (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), - (r'\\w \\w|\\w K. \\w', False, False), - (r'\\w \\w|\\w K. \\w', False, True), - (r'\\w \\w|\\w K. \\w', True, True), - (r'\\w \\w|\\w X. \\w', False, False), - (r'\\w \\w|\\w X. \\w', False, True), - (r'\\w \\w|\\w X. \\w', True, True), - (r'\\w \\w|\\w a. \\w', False, False), - (r'\\w \\w|\\w a. \\w', False, True), - (r'\\w \\w|\\w a. \\w', True, True), - (r'\\w \\w|\\w k. \\w', False, False), - (r'\\w \\w|\\w k. \\w', False, True), - (r'\\w \\w|\\w k. \\w', True, True), - (r'\\w \\w|\\w x. \\w', False, False), - (r'\\w \\w|\\w x. \\w', False, True), - (r'\\w \\w|\\w x. \\w', True, True), - (r'\\w a. \\w', False, True), - (r'\\w a. \\w|\\w \\w', False, False), - (r'\\w k. \\w', False, True), - (r'\\w k. \\w|\\w \\w', False, False), - (r'\n', False, True), - (r'\n', True, True), - (r'\v', False, True), - (r'\v', True, True), - (r'\v|\v-\v', False, True), - (r'\v|\v-\v', True, True), - (r'short string|a much longer string which is bigger than short string', False, True), - (r'short string|a much longer string which is bigger than short string', True, True), - (r'\w A. \w', False, False), - (r'\w \a. \w', True, True), - (r'\w \k. \w', True, True), - (r'\w \n \w', True, True), - (r'\w \w|\w A. \w', False, False), - (r'\w \w|\w \A. \w', True, True), - (r'\w \w|\w \a. \w', True, True), - (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), - (r'\w aAdDkK \w', False, False), - (r'\w aAdDkKxX \n \N \w', False, False), - (r'\w', False, False), - (r'\w', False, True), - (r'\w', True, True), - (r'a', False, True), - (r'b', False, False), - (r'b', False, True), - (r'b', True, True), - (r'd', False, True), - (r'k', False, True), - (r'x', False, True), - ('', False, False), - ('', False, True), - (r'', True, True), - ('|', False, False), - ('|', False, True), - (r'|', True, True), - (r'\ww - not e\xpecting two wor\ds', False, False), - (r'\ww - not expecting two words', True, True) - ]) + @pytest.mark.parametrize( + "templateProvided, escapeSpecial, useTemplateObject", + [ # (r'\w \w|\w \v. \w', False, False), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\w \w|\w a. \w', False, False), + (r'\w.\w@\w.com', False, False), + (r'\n-\n', False, False), + (r'A', False, True), + (r'D', False, True), + (r'K', False, True), + (r'X', False, True), + (r'\W', False, True), + (r'\W', True, True), + (r'\\w A. \\w|\\w \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, False), + (r'\\w \\w|\\w A. \\w', False, True), + (r'\\w \\w|\\w A. \\w', True, True), + (r'\\w \\w|\\w A. \\w|\w n n \w', False, False), + (r'\\w \\w|\\w K. \\w', False, False), + (r'\\w \\w|\\w K. \\w', False, True), + (r'\\w \\w|\\w K. \\w', True, True), + (r'\\w \\w|\\w X. \\w', False, False), + (r'\\w \\w|\\w X. \\w', False, True), + (r'\\w \\w|\\w X. \\w', True, True), + (r'\\w \\w|\\w a. \\w', False, False), + (r'\\w \\w|\\w a. \\w', False, True), + (r'\\w \\w|\\w a. \\w', True, True), + (r'\\w \\w|\\w k. \\w', False, False), + (r'\\w \\w|\\w k. 
\\w', False, True), + (r'\\w \\w|\\w k. \\w', True, True), + (r'\\w \\w|\\w x. \\w', False, False), + (r'\\w \\w|\\w x. \\w', False, True), + (r'\\w \\w|\\w x. \\w', True, True), + (r'\\w a. \\w', False, True), + (r'\\w a. \\w|\\w \\w', False, False), + (r'\\w k. \\w', False, True), + (r'\\w k. \\w|\\w \\w', False, False), + (r'\n', False, True), + (r'\n', True, True), + (r'\v', False, True), + (r'\v', True, True), + (r'\v|\v-\v', False, True), + (r'\v|\v-\v', True, True), + (r'short string|a much longer string which is bigger than short string', False, True), + (r'short string|a much longer string which is bigger than short string', True, True), + (r'\w A. \w', False, False), + (r'\w \a. \w', True, True), + (r'\w \k. \w', True, True), + (r'\w \n \w', True, True), + (r'\w \w|\w A. \w', False, False), + (r'\w \w|\w \A. \w', True, True), + (r'\w \w|\w \a. \w', True, True), + (r'\w \w|\w \w \w|\w \n \w|\w \w \w \w', True, True), + (r'\w aAdDkK \w', False, False), + (r'\w aAdDkKxX \n \N \w', False, False), + (r'\w', False, False), + (r'\w', False, True), + (r'\w', True, True), + (r'a', False, True), + (r'b', False, False), + (r'b', False, True), + (r'b', True, True), + (r'd', False, True), + (r'k', False, True), + (r'x', False, True), + ('', False, False), + ('', False, True), + (r'', True, True), + ('|', False, False), + ('|', False, True), + (r'|', True, True), + (r'\ww - not e\xpecting two wor\ds', False, False), + (r'\ww - not expecting two words', True, True), + ], + ) def test_use_pandas(self, templateProvided, escapeSpecial, useTemplateObject): template1 = TemplateGenerator(templateProvided, escapeSpecialChars=escapeSpecial) @@ -246,14 +252,17 @@ def test_use_pandas(self, templateProvided, escapeSpecial, useTemplateObject): for i, result_value in enumerate(results): print(f"{i}: '{result_value}'") - @pytest.mark.parametrize("templateProvided, escapeSpecial, useTemplateObject", - [(r'\n', False, True), - (r'\n', True, True), - (r'\v', False, True), - (r'\v', True, True), - (r'\v|\v-\v', False, True), - (r'\v|\v-\v', True, True), - ]) + @pytest.mark.parametrize( + "templateProvided, escapeSpecial, useTemplateObject", + [ + (r'\n', False, True), + (r'\n', True, True), + (r'\v', False, True), + (r'\v', True, True), + (r'\v|\v-\v', False, True), + (r'\v|\v-\v', True, True), + ], + ) def test_sub_value1(self, templateProvided, escapeSpecial, useTemplateObject): template1 = TemplateGenerator(templateProvided, escapeSpecialChars=escapeSpecial) diff --git a/tests/test_topological_sort.py b/tests/test_topological_sort.py index b5e6a72f..3bec0b4e 100644 --- a/tests/test_topological_sort.py +++ b/tests/test_topological_sort.py @@ -13,7 +13,7 @@ def test_sort(self): ('id', []), ('code3a', []), ('_r_code1', []), - ('_r_code3', []) + ('_r_code3', []), ] output = list(dg.topologicalSort(src)) @@ -29,7 +29,7 @@ def test_sort2(self): ('id', []), ('code3a', []), ('_r_code1', ['id']), - ('_r_code3', ['id']) + ('_r_code3', ['id']), ] output = list(dg.topologicalSort(src, initial_columns=['id'], flatten=False)) diff --git a/tests/test_types.py b/tests/test_types.py index 30374d2e..42754bce 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -24,12 +24,13 @@ def test_basic_types(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=100000, partitions=id_partitions, verbose=True) - .withColumn("code1", IntegerType(), minValue=1, maxValue=20, step=1) - .withColumn("code2", LongType(), maxValue=1000, step=5) - .withColumn("code3", IntegerType(), minValue=100, 
maxValue=200, step=1, random=True) - .withColumn("xcode", StringType(), values=["a", "test", "value"], random=True) - .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True) - .withColumn("drating", DoubleType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True)) + .withColumn("code1", IntegerType(), minValue=1, maxValue=20, step=1) + .withColumn("code2", LongType(), maxValue=1000, step=5) + .withColumn("code3", IntegerType(), minValue=100, maxValue=200, step=1, random=True) + .withColumn("xcode", StringType(), values=["a", "test", "value"], random=True) + .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True) + .withColumn("drating", DoubleType(), minValue=1.0, maxValue=5.0, step=0.00001, random=True) + ) df = testdata_defn.build().cache() df.printSchema() @@ -50,10 +51,11 @@ def test_reduced_range_types(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=num_rows, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", ByteType(), minValue=1, maxValue=20, step=1) - .withColumn("code2", ShortType(), maxValue=1000, step=5)) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", ByteType(), minValue=1, maxValue=20, step=1) + .withColumn("code2", ShortType(), maxValue=1000, step=5) + ) testdata_defn.build().createOrReplaceTempView("testdata") df = spark.sql("select * from testdata order by basic_short desc, basic_byte desc") @@ -61,10 +63,12 @@ def test_reduced_range_types(self): self.assertEqual(df.count(), num_rows) # check that range of code1 and code2 matches expectations - df_min_max = df.agg(F.min("code1").alias("min_code1"), - F.max("code1").alias("max_code1"), - F.min("code2").alias("min_code2"), - F.max("code2").alias("max_code2")) + df_min_max = df.agg( + F.min("code1").alias("min_code1"), + F.max("code1").alias("max_code1"), + F.min("code2").alias("min_code2"), + F.max("code2").alias("max_code2"), + ) limits = df_min_max.collect()[0] self.assertEqual(limits["min_code2"], 0) @@ -84,9 +88,10 @@ def test_out_of_range_types(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", ByteType(), minValue=1, maxValue=400, step=1)) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", ByteType(), minValue=1, maxValue=400, step=1) + ) testdata_defn.build().createOrReplaceTempView("testdata") spark.sql("select * from testdata order by basic_short desc, basic_byte desc").show() @@ -95,9 +100,10 @@ def test_for_reverse_range(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", ByteType(), minValue=127, maxValue=1, step=-1)) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", ByteType(), minValue=127, maxValue=1, step=-1) + ) df = testdata_defn.build().limit(130) data_row1 = df.collect() @@ -111,9 +117,9 @@ def test_for_reverse_range2(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - 
.withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", ByteType(), minValue=127, maxValue=1, step=-1) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", ByteType(), minValue=127, maxValue=1, step=-1) ) df = testdata_defn.build().limit(130) @@ -125,11 +131,9 @@ def test_for_values_with_multi_column_dependencies(self): code_values = ["aa", "bb", "cc", "dd", "ee", "ff"] testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=code_values, - baseColumn=["basic_byte", "basic_short"]) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=code_values, baseColumn=["basic_byte", "basic_short"]) ) df = testdata_defn.build().where("code1 is null") @@ -152,11 +156,9 @@ def test_for_values_with_single_column_dependencies(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=["aa", "bb", "cc", "dd", "ee", "ff"], - baseColumn=["basic_byte"]) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=["aa", "bb", "cc", "dd", "ee", "ff"], baseColumn=["basic_byte"]) ) df = testdata_defn.build().where("code1 is null") self.assertEqual(df.count(), 0) @@ -166,12 +168,10 @@ def test_for_values_with_single_column_dependencies2(self): rows_wanted = 1000000 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=rows_wanted, partitions=id_partitions, verbose=True) - .withIdOutput() - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=["aa", "bb", "cc", "dd", "ee", "ff"], - baseColumn=["basic_byte"]) + .withIdOutput() + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=["aa", "bb", "cc", "dd", "ee", "ff"], baseColumn=["basic_byte"]) ) df = testdata_defn.build() # df.show() @@ -182,10 +182,9 @@ def test_for_values_with_default_column_dependencies(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=["aa", "bb", "cc", "dd", "ee", "ff"]) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=["aa", "bb", "cc", "dd", "ee", "ff"]) ) df = testdata_defn.build().where("code1 is null") self.assertEqual(df.count(), 0) @@ -195,11 +194,9 @@ def test_for_weighted_values_with_default_column_dependencies(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=["aa", "bb", "cc", "dd", "ee", "ff"], - weights=[1, 2, 3, 4, 5, 6]) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=["aa", "bb", "cc", "dd", 
"ee", "ff"], weights=[1, 2, 3, 4, 5, 6]) ) df = testdata_defn.build().where("code1 is null") self.assertEqual(df.count(), 0) @@ -208,12 +205,10 @@ def test_for_weighted_values_with_default_column_dependencies2(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withIdOutput() - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - .withColumn("code1", StringType(), - values=["aa", "bb", "cc", "dd", "ee", "ff"], - weights=[1, 2, 3, 4, 5, 6]) + .withIdOutput() + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code1", StringType(), values=["aa", "bb", "cc", "dd", "ee", "ff"], weights=[1, 2, 3, 4, 5, 6]) ) df = testdata_defn.build().where("code1 is null") df.show() @@ -223,10 +218,10 @@ def test_out_of_range_types2(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("basic_byte", ByteType()) - .withColumn("basic_short", ShortType()) - - .withColumn("code2", ShortType(), maxValue=80000, step=5)) + .withColumn("basic_byte", ByteType()) + .withColumn("basic_short", ShortType()) + .withColumn("code2", ShortType(), maxValue=80000, step=5) + ) testdata_defn.build().createOrReplaceTempView("testdata") spark.sql("select * from testdata order by basic_short desc, basic_byte desc").show() @@ -235,10 +230,10 @@ def test_short_types1(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("bb", ByteType(), unique_values=100) - .withColumn("basic_short", ShortType()) - - .withColumn("code2", ShortType(), maxValue=10000, step=5)) + .withColumn("bb", ByteType(), unique_values=100) + .withColumn("basic_short", ShortType()) + .withColumn("code2", ShortType(), maxValue=10000, step=5) + ) testdata_defn.build().createOrReplaceTempView("testdata") data_row = spark.sql("select min(bb) as min_bb, max(bb) as max_bb from testdata ").limit(1).collect() @@ -249,10 +244,10 @@ def test_short_types1a(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("bb", ByteType(), minValue=35, maxValue=72) - .withColumn("basic_short", ShortType()) - - .withColumn("code2", ShortType(), maxValue=10000, step=5)) + .withColumn("bb", ByteType(), minValue=35, maxValue=72) + .withColumn("basic_short", ShortType()) + .withColumn("code2", ShortType(), maxValue=10000, step=5) + ) testdata_defn.build().createOrReplaceTempView("testdata") data_row = spark.sql("select min(bb) as min_bb, max(bb) as max_bb from testdata ").limit(1).collect() @@ -265,9 +260,10 @@ def test_short_types1b(self): # result should be the same whether using `minValue` or `min` as options testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("bb", ByteType(), minValue=35, maxValue=72) - .withColumn("basic_short", ShortType()) - .withColumn("code2", ShortType(), maxValue=10000, step=5)) + .withColumn("bb", ByteType(), minValue=35, maxValue=72) + .withColumn("basic_short", ShortType()) + .withColumn("code2", ShortType(), maxValue=10000, step=5) + ) testdata_defn.build().createOrReplaceTempView("testdata") data_row = spark.sql("select min(bb) as min_bb, max(bb) as max_bb from testdata ").limit(1).collect() @@ -278,10 +274,10 @@ def test_short_types2(self): 
id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withColumn("bb", ByteType(), unique_values=100) - .withColumn("basic_short", ShortType()) - - .withColumn("code2", ShortType(), maxValue=4000, step=5)) + .withColumn("bb", ByteType(), unique_values=100) + .withColumn("basic_short", ShortType()) + .withColumn("code2", ShortType(), maxValue=4000, step=5) + ) testdata_defn.build().show() @@ -289,12 +285,13 @@ def test_decimal(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withIdOutput() - .withColumn("code1", DecimalType(10, 3)) - .withColumn("code2", DecimalType(10, 5)) - .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) - .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) - .withColumn("code5", DecimalType(10, 5), minValue=1.0, maxValue=1000.0, random=True, continuous=True)) + .withIdOutput() + .withColumn("code1", DecimalType(10, 3)) + .withColumn("code2", DecimalType(10, 5)) + .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) + .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) + .withColumn("code5", DecimalType(10, 5), minValue=1.0, maxValue=1000.0, random=True, continuous=True) + ) df = testdata_defn.build() df.show() @@ -303,12 +300,13 @@ def test_decimal2(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - .withIdOutput() - .withColumn("code1", DecimalType(10, 3)) - .withColumn("code2", DecimalType(10, 5)) - .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) - .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) - .withColumn("code5", DecimalType(10, 5), minValue=1.0, maxValue=1000.0, random=True, continuous=True)) + .withIdOutput() + .withColumn("code1", DecimalType(10, 3)) + .withColumn("code2", DecimalType(10, 5)) + .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) + .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) + .withColumn("code5", DecimalType(10, 5), minValue=1.0, maxValue=1000.0, random=True, continuous=True) + ) testdata_defn.build().createOrReplaceTempView("testdata") @@ -316,18 +314,19 @@ def test_decimal_min_and_max_values(self): id_partitions = 4 testdata_defn = ( dg.DataGenerator(name="basic_dataset", rows=1000000, partitions=id_partitions, verbose=True) - - .withIdOutput() - .withColumn("group1", IntegerType(), expr="1") - .withColumn("code1", DecimalType(10, 3)) - .withColumn("code2", DecimalType(10, 5)) - .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) - .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) - .withColumn("code5", DecimalType(10, 5), minValue=2.0, maxValue=2000.0, random=True, continuous=True)) + .withIdOutput() + .withColumn("group1", IntegerType(), expr="1") + .withColumn("code1", DecimalType(10, 3)) + .withColumn("code2", DecimalType(10, 5)) + .withColumn("code3", DecimalType(10, 5), minValue=1.0, maxValue=1000.0) + .withColumn("code4", DecimalType(10, 5), random=True, continuous=True) + .withColumn("code5", DecimalType(10, 5), minValue=2.0, maxValue=2000.0, random=True, continuous=True) + ) testdata_defn.build().createOrReplaceTempView("testdata") - df2 = spark.sql("""select min(code1) as min1, max(code1) as max1, + df2 = spark.sql( + """select min(code1) as min1, max(code1) as max1, 
min(code2) as min2, max(code2) as max2 , min(code3) as min3, @@ -336,16 +335,27 @@ def test_decimal_min_and_max_values(self): max(code4) as max4, min(code5) as min5, max(code5) as max5 - from testdata group by group1 """) + from testdata group by group1 """ + ) results = df2.collect()[0] print(results) - min1, min2, min3, min4, min5 = results['max1'], results['min2'], results['min3'], results['min4'], results[ - 'min5'] - max1, max2, max3, max4, max5 = results['max1'], results['max2'], results['max3'], results['max4'], results[ - 'max5'] + min1, min2, min3, min4, min5 = ( + results['max1'], + results['min2'], + results['min3'], + results['min4'], + results['min5'], + ) + max1, max2, max3, max4, max5 = ( + results['max1'], + results['max2'], + results['max3'], + results['max4'], + results['max5'], + ) self.assertGreaterEqual(min1, 0.0) self.assertGreaterEqual(min2, 0.0) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7975408e..9c89c0e8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,9 +3,20 @@ import pytest -from dbldatagen import ensure, mkBoundsList, coalesce_values, deprecated, SparkSingleton, \ - parse_time_interval, DataGenError, strip_margins, split_list_matching_condition, topologicalSort, \ - json_value_from_path, system_time_millis +from dbldatagen import ( + ensure, + mkBoundsList, + coalesce_values, + deprecated, + SparkSingleton, + parse_time_interval, + DataGenError, + strip_margins, + split_list_matching_condition, + topologicalSort, + json_value_from_path, + system_time_millis, +) spark = SparkSingleton.getLocalInstance("unit tests") @@ -25,44 +36,51 @@ def test_ensure(self): with pytest.raises(Exception): ensure(1 == 2, "Expected error") # pylint: disable=comparison-of-constants - @pytest.mark.parametrize("value,defaultValues", - [(None, 1), - (None, [1, 2]), - (5, [1, 2]), - (5, 1), - ([1, 2], [3, 4]), - ]) + @pytest.mark.parametrize( + "value,defaultValues", + [ + (None, 1), + (None, [1, 2]), + (5, [1, 2]), + (5, 1), + ([1, 2], [3, 4]), + ], + ) def test_mkBoundsList1(self, value, defaultValues): - """ Test utils mkBoundsList""" + """Test utils mkBoundsList""" test = mkBoundsList(value, defaultValues) assert len(test) == 2 - @pytest.mark.parametrize("test_input,expected", - [ - ([None, 1], 1), - ([2, 1], 2), - ([3, None, 1], 3), - ([None, None, None], None), - ]) + @pytest.mark.parametrize( + "test_input,expected", + [ + ([None, 1], 1), + ([2, 1], 2), + ([3, None, 1], 3), + ([None, None, None], None), + ], + ) def test_coalesce(self, test_input, expected): - """ Test utils coalesce function""" + """Test utils coalesce function""" result = coalesce_values(*test_input) assert result == expected - @pytest.mark.parametrize("test_input,expected", - [ - ("1 hours, minutes = 2", timedelta(hours=1, minutes=2)), - ("4 days, 1 hours, 2 minutes", timedelta(days=4, hours=1, minutes=2)), - ("days=4, hours=1, minutes=2", timedelta(days=4, hours=1, minutes=2)), - ("1 hours, 2 seconds", timedelta(hours=1, seconds=2)), - ("1 hours, 2 minutes", timedelta(hours=1, minutes=2)), - ("1 hours", timedelta(hours=1)), - ("1 hour", timedelta(hours=1)), - ("1 hour, 1 second", timedelta(hours=1, seconds=1)), - ("1 hour, 10 milliseconds", timedelta(hours=1, milliseconds=10)), - ("1 hour, 10 microseconds", timedelta(hours=1, microseconds=10)), - ("1 year, 4 weeks", timedelta(weeks=56)) - ]) + @pytest.mark.parametrize( + "test_input,expected", + [ + ("1 hours, minutes = 2", timedelta(hours=1, minutes=2)), + ("4 days, 1 hours, 2 minutes", timedelta(days=4, hours=1, 
minutes=2)), + ("days=4, hours=1, minutes=2", timedelta(days=4, hours=1, minutes=2)), + ("1 hours, 2 seconds", timedelta(hours=1, seconds=2)), + ("1 hours, 2 minutes", timedelta(hours=1, minutes=2)), + ("1 hours", timedelta(hours=1)), + ("1 hour", timedelta(hours=1)), + ("1 hour, 1 second", timedelta(hours=1, seconds=1)), + ("1 hour, 10 milliseconds", timedelta(hours=1, milliseconds=10)), + ("1 hour, 10 microseconds", timedelta(hours=1, microseconds=10)), + ("1 year, 4 weeks", timedelta(weeks=56)), + ], + ) def testParseTimeInterval2b(self, test_input, expected): interval = parse_time_interval(test_input) assert expected == interval @@ -78,51 +96,101 @@ def testDatagenExceptionObject(self): assert type(str(testException)) is str self.logger.info(str(testException)) - @pytest.mark.parametrize("inputText,expectedText", - [("""one + @pytest.mark.parametrize( + "inputText,expectedText", + [ + ( + """one |two |three""", - "one\ntwo\nthree"), - ("", ""), - ("one\ntwo", "one\ntwo"), - (" one\ntwo", " one\ntwo"), - (" |one\ntwo", "one\ntwo"), - ]) + "one\ntwo\nthree", + ), + ("", ""), + ("one\ntwo", "one\ntwo"), + (" one\ntwo", " one\ntwo"), + (" |one\ntwo", "one\ntwo"), + ], + ) def test_strip_margins(self, inputText, expectedText): output = strip_margins(inputText, '|') assert output == expectedText - @pytest.mark.parametrize("lstData,matchFn, expectedData", - [ - (['id', 'city_name', 'id', 'city_id', 'city_pop', 'id', 'city_id', - 'city_pop', 'city_id', 'city_pop', 'id'], - lambda el: el == 'id', - [['id'], ['city_name'], ['id'], ['city_id', 'city_pop'], ['id'], - ['city_id', 'city_pop', 'city_id', 'city_pop'], ['id']] - ), - (['id', 'city_name', 'id', 'city_id', 'city_pop', 'id', 'city_id', - 'city_pop2', 'city_id', 'city_pop', 'id'], - lambda el: el in ['id', 'city_pop'], - [['id'], ['city_name'], ['id'], ['city_id'], ['city_pop'], ['id'], - ['city_id', 'city_pop2', 'city_id'], ['city_pop'], ['id']] - ), - ([], lambda el: el == 'id', []), - (['id'], lambda el: el == 'id', [['id']]), - (['id', 'id'], lambda el: el == 'id', [['id'], ['id']]), - (['no', 'matches'], lambda el: el == 'id', [['no', 'matches']]) - ]) + @pytest.mark.parametrize( + "lstData,matchFn, expectedData", + [ + ( + [ + 'id', + 'city_name', + 'id', + 'city_id', + 'city_pop', + 'id', + 'city_id', + 'city_pop', + 'city_id', + 'city_pop', + 'id', + ], + lambda el: el == 'id', + [ + ['id'], + ['city_name'], + ['id'], + ['city_id', 'city_pop'], + ['id'], + ['city_id', 'city_pop', 'city_id', 'city_pop'], + ['id'], + ], + ), + ( + [ + 'id', + 'city_name', + 'id', + 'city_id', + 'city_pop', + 'id', + 'city_id', + 'city_pop2', + 'city_id', + 'city_pop', + 'id', + ], + lambda el: el in ['id', 'city_pop'], + [ + ['id'], + ['city_name'], + ['id'], + ['city_id'], + ['city_pop'], + ['id'], + ['city_id', 'city_pop2', 'city_id'], + ['city_pop'], + ['id'], + ], + ), + ([], lambda el: el == 'id', []), + (['id'], lambda el: el == 'id', [['id']]), + (['id', 'id'], lambda el: el == 'id', [['id'], ['id']]), + (['no', 'matches'], lambda el: el == 'id', [['no', 'matches']]), + ], + ) def testSplitListOnCondition(self, lstData, matchFn, expectedData): results = split_list_matching_condition(lstData, matchFn) print(results) assert results == expectedData - @pytest.mark.parametrize("dependencies, raisesError", - [([], False), - ([("id", []), ("name", ["id"]), ("name2", ["name"])], False), - ([("id", []), ("name", ["id"]), ("name2", ["name3"]), ("name3", ["name2"])], True), - ]) + @pytest.mark.parametrize( + "dependencies, raisesError", + [ + ([], 
False), + ([("id", []), ("name", ["id"]), ("name2", ["name"])], False), + ([("id", []), ("name", ["id"]), ("name2", ["name3"]), ("name3", ["name2"])], True), + ], + ) def test_topological_sort(self, dependencies, raisesError): raised_exception = False try: @@ -134,12 +202,15 @@ def test_topological_sort(self, dependencies, raisesError): assert raised_exception == raisesError - @pytest.mark.parametrize("path,jsonData, defaultValue, expectedValue", - [("a", """{"a":1,"b":2,"c":[1,2,3]}""", None, 1), - ("b", """{"a":1,"b":2,"c":[1,2,3]}""", None, 2), - ("d", """{"a":1,"b":2,"c":[1,2,3]}""", 42, 42), - ("c[2]", """{"a":1,"b":2,"c":[1,2,3]}""", None, 3), - ]) + @pytest.mark.parametrize( + "path,jsonData, defaultValue, expectedValue", + [ + ("a", """{"a":1,"b":2,"c":[1,2,3]}""", None, 1), + ("b", """{"a":1,"b":2,"c":[1,2,3]}""", None, 2), + ("d", """{"a":1,"b":2,"c":[1,2,3]}""", 42, 42), + ("c[2]", """{"a":1,"b":2,"c":[1,2,3]}""", None, 3), + ], + ) def test_json_value_from_path(self, path, jsonData, defaultValue, expectedValue): results = json_value_from_path(path, jsonData, defaultValue) diff --git a/tests/test_weights.py b/tests/test_weights.py index e8f3e3e2..940d12ae 100644 --- a/tests/test_weights.py +++ b/tests/test_weights.py @@ -20,14 +20,13 @@ def setUpClass(cls): # will have implied column `id` for ordinal of row cls.testdata_generator = ( dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=cls.rows, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True) - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, - prefix='status', random=True) - .withColumn("tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], - weights=desired_weights, - random=True) + .withIdOutput() # id column will be emitted in the output + .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) + .withColumn("code4", "integer", minValue=1, maxValue=40, step=1, random=True) + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + .withColumn( + "tech", "string", values=["GSM", "LTE", "UMTS", "UNKNOWN"], weights=desired_weights, random=True + ) ) cls.testdata_generator.build().cache().createOrReplaceTempView("testdata") @@ -52,11 +51,14 @@ def get_observed_weights(cls, df, column, values): assert col is not None assert values is not None - observed_weights = (df.cube(column).count() - .withColumnRenamed(column, "value") - .withColumnRenamed("count", "rc") - .where("value is not null") - .collect()) + observed_weights = ( + df.cube(column) + .count() + .withColumnRenamed(column, "value") + .withColumnRenamed("count", "rc") + .where("value is not null") + .collect() + ) print(observed_weights) @@ -79,19 +81,16 @@ def assertPercentagesEqual(self, percentages, desired_percentages, target_delta= self.assertAlmostEqual(x, y, delta=float(x) * target_delta) def test_get_observed_weights(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("pk1", "int", 
unique_values=100) - .withColumn("alpha", "string", values=alpha_list, baseColumn="pk1", - weights=alpha_desired_weights, random=True) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("pk1", "int", unique_values=100) + .withColumn( + "alpha", "string", values=alpha_list, baseColumn="pk1", weights=alpha_desired_weights, random=True + ) + ) dfAlpha = dsAlpha.build().cache() values = dsAlpha['alpha'].values @@ -118,7 +117,8 @@ def test_basic2(self): def test_generate_values(self): df_values = spark.sql( - "select * from (select tech, count(tech) as rc from testdata group by tech ) a order by tech").collect() + "select * from (select tech, count(tech) as rc from testdata group by tech ) a order by tech" + ).collect() values = [x.tech for x in df_values] print("row values:", values) total_count = sum([x.rc for x in df_values]) # pylint: disable=consider-using-generator @@ -130,19 +130,13 @@ def test_generate_values(self): self.assertPercentagesEqual(percentages, desired_percentages) def test_weighted_distribution(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("alpha", "string", values=alpha_list, - weights=alpha_desired_weights, - random=True) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("alpha", "string", values=alpha_list, weights=alpha_desired_weights, random=True) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -152,21 +146,16 @@ def test_weighted_distribution(self): self.assertPercentagesEqual(percentages, desired_percentages) def test_weighted_distribution_nr(self): - """ Test distribution of values with weights for non random values""" - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + """Test distribution of values with weights for non random values""" + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") # dont use seed value as non random fields should be repeatable - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("alpha", "string", values=alpha_list, - weights=alpha_desired_weights) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("alpha", "string", values=alpha_list, weights=alpha_desired_weights) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -177,22 +166,17 @@ def test_weighted_distribution_nr(self): # @unittest.skip("not yet finalized") def test_weighted_distribution_nr2(self): - alpha_desired_weights = [9, 4, 1, 
10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") # dont use seed value as non random fields should be repeatable - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("pk1", "int", unique_values=500) - .withColumn("pk2", "int", unique_values=500) - .withColumn("alpha", "string", values=alpha_list, baseColumn="pk1", - weights=alpha_desired_weights) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("pk1", "int", unique_values=500) + .withColumn("pk2", "int", unique_values=500) + .withColumn("alpha", "string", values=alpha_list, baseColumn="pk1", weights=alpha_desired_weights) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -204,11 +188,13 @@ def test_weighted_distribution_nr2(self): # for columns with non random values and a single base dependency `pk1` # each combination of pk1 and alpha should be the same - df_counts = (dfAlpha.cube("pk1", "alpha") - .count() - .where("pk1 is not null and alpha is not null") - .orderBy("pk1").withColumnRenamed("count", "rc") - ) + df_counts = ( + dfAlpha.cube("pk1", "alpha") + .count() + .where("pk1 is not null and alpha is not null") + .orderBy("pk1") + .withColumnRenamed("count", "rc") + ) # get counts for each primary key from the cube # they should be 1 for each primary key @@ -216,20 +202,17 @@ def test_weighted_distribution_nr2(self): self.assertEqual(df_counts_by_key.where("rc > 1").count(), 0) def test_weighted_distribution2(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("pk1", "int", unique_values=500) - .withColumn("pk2", "int", unique_values=500) - .withColumn("alpha", "string", values=alpha_list, baseColumn="pk1", - weights=alpha_desired_weights, random=True) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("pk1", "int", unique_values=500) + .withColumn("pk2", "int", unique_values=500) + .withColumn( + "alpha", "string", values=alpha_list, baseColumn="pk1", weights=alpha_desired_weights, random=True + ) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -239,20 +222,22 @@ def test_weighted_distribution2(self): self.assertPercentagesEqual(percentages, desired_percentages) def test_weighted_distribution3(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") - dsAlpha = 
(dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("pk1", "int", unique_values=500) - .withColumn("pk2", "int", unique_values=500) - .withColumn("alpha", "string", values=alpha_list, baseColumn=["pk1", "pk2"], - weights=alpha_desired_weights, random=True) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("pk1", "int", unique_values=500) + .withColumn("pk2", "int", unique_values=500) + .withColumn( + "alpha", + "string", + values=alpha_list, + baseColumn=["pk1", "pk2"], + weights=alpha_desired_weights, + random=True, + ) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -262,22 +247,17 @@ def test_weighted_distribution3(self): self.assertPercentagesEqual(percentages, desired_percentages) def test_weighted_distribution_nr3(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") # dont use seed value as non random fields should be repeatable - dsAlpha = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4, debug=True) - .withIdOutput() # id column will be emitted in the output - .withColumn("pk1", "int", unique_values=500) - .withColumn("pk2", "int", unique_values=500) - .withColumn("alpha", "string", values=alpha_list, baseColumn=["pk1", "pk2"], - weights=alpha_desired_weights) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4, debug=True) + .withIdOutput() # id column will be emitted in the output + .withColumn("pk1", "int", unique_values=500) + .withColumn("pk2", "int", unique_values=500) + .withColumn("alpha", "string", values=alpha_list, baseColumn=["pk1", "pk2"], weights=alpha_desired_weights) + ) dfAlpha = dsAlpha.build().cache() observed_weights = self.get_observed_weights(dfAlpha, 'alpha', dsAlpha['alpha'].values) @@ -291,11 +271,13 @@ def test_weighted_distribution_nr3(self): # for columns with non random values and base dependency on `pk1` and `pk2` # each combination of pk1, pk2 and alpha should be the same - df_counts = (dfAlpha.cube("pk1", "pk2", "alpha") - .count() - .where("pk1 is not null and alpha is not null and pk2 is not null") - .orderBy("pk1", "pk2").withColumnRenamed("count", "rc") - ) + df_counts = ( + dfAlpha.cube("pk1", "pk2", "alpha") + .count() + .where("pk1 is not null and alpha is not null and pk2 is not null") + .orderBy("pk1", "pk2") + .withColumnRenamed("count", "rc") + ) # get counts for each primary key from the cube # they should be 1 for each primary key @@ -305,12 +287,11 @@ def test_weighted_distribution_nr3(self): def test_weighted_distribution_int(self): num_desired_weights = [9, 4, 1, 10, 5] num_list = [1, 2, 3, 4, 5] - dsInt1 = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("code", "integer", values=num_list, - weights=num_desired_weights, - random=True) - ) + dsInt1 = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4) + .withIdOutput() # id column 
will be emitted in the output + .withColumn("code", "integer", values=num_list, weights=num_desired_weights, random=True) + ) dfInt1 = dsInt1.build().cache() observed_weights = self.get_observed_weights(dfInt1, 'code', dsInt1['code'].values) @@ -321,16 +302,16 @@ def test_weighted_distribution_int(self): self.assertPercentagesEqual(percentages, desired_percentages) def test_weighted_nr_int(self): - """ Test distribution of non-random values where field is a integer""" + """Test distribution of non-random values where field is a integer""" num_desired_weights = [9, 4, 1, 10, 5] num_list = [1, 2, 3, 4, 5] # dont use seed value as non random fields should be repeatable - dsInt1 = (dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4, debug=True) - .withIdOutput() # id column will be emitted in the output - .withColumn("code", "integer", values=num_list, - weights=num_desired_weights, base_column_type="hash") - ) + dsInt1 = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 10000, partitions=4, debug=True) + .withIdOutput() # id column will be emitted in the output + .withColumn("code", "integer", values=num_list, weights=num_desired_weights, base_column_type="hash") + ) dfInt1 = dsInt1.build().cache() observed_weights = self.get_observed_weights(dfInt1, 'code', dsInt1['code'].values) @@ -342,23 +323,15 @@ def test_weighted_nr_int(self): # @unittest.skip("not yet finalized") def test_weighted_repeatable_non_random(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") # dont use seed value as non random fields should be repeatable - dsAlpha = (dg.DataGenerator(sparkSession=spark, - name="test_dataset1", - rows=26 * 1000, - partitions=4) - .withIdOutput() # id column will be emitted in the output - .withColumn("alpha", "string", values=alpha_list, - weights=alpha_desired_weights) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 1000, partitions=4) + .withIdOutput() # id column will be emitted in the output + .withColumn("alpha", "string", values=alpha_list, weights=alpha_desired_weights) + ) dfAlpha = dsAlpha.build().limit(100).cache() values1 = dfAlpha.collect() @@ -370,24 +343,15 @@ def test_weighted_repeatable_non_random(self): self.assertEqual(values1, values2) def test_weighted_repeatable_random(self): - alpha_desired_weights = [9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, - 9, 4, 1, 10, 5, 9 - ] + alpha_desired_weights = [9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9, 4, 1, 10, 5, 9] alpha_list = list("abcdefghijklmnopqrstuvwxyz") # use seed for random repeatability - dsAlpha = (dg.DataGenerator(sparkSession=spark, - name="test_dataset1", - rows=26 * 1000, - partitions=4, - seed=43) - .withIdOutput() # id column will be emitted in the output - .withColumn("alpha", "string", values=alpha_list, - weights=alpha_desired_weights, random=True) - ) + dsAlpha = ( + dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=26 * 1000, partitions=4, seed=43) + .withIdOutput() # id column will be emitted in the output + .withColumn("alpha", "string", values=alpha_list, weights=alpha_desired_weights, random=True) + ) dfAlpha = dsAlpha.build().limit(100).cache() values1 = dfAlpha.collect() diff --git a/tutorial/1-Introduction.py 
b/tutorial/1-Introduction.py index 0764394a..74167421 100644 --- a/tutorial/1-Introduction.py +++ b/tutorial/1-Introduction.py @@ -4,7 +4,7 @@ # COMMAND ---------- # MAGIC %md ### First Steps ### -# MAGIC +# MAGIC # MAGIC You will need to import the data generator library in order to use it. # MAGIC # MAGIC Within a notebook, you can install the package from PyPi using `pip install` to install the @@ -24,9 +24,9 @@ # COMMAND ---------- # MAGIC %md ### Brief Introduction ### -# MAGIC +# MAGIC # MAGIC You can use the data generator to -# MAGIC +# MAGIC # MAGIC * Generate Pyspark data frames from individual column declarations and schema definitions # MAGIC * Augment the schema and column definitions with directives as to how data should be generated # MAGIC * specify weighting of values @@ -37,19 +37,19 @@ # MAGIC * specify arbitrary SQL expressions # MAGIC * customize generation of text, timestamps, date and other data # MAGIC * All of the above can be done within the Databricks notebook environment -# MAGIC +# MAGIC # MAGIC See the help information in the [repository documentation files](https://github.com/databrickslabs/dbldatagen/blob/master/docs/source/APIDOCS.md) and in the [online help Github Pages](https://databrickslabs.github.io/dbldatagen/) for more details. -# MAGIC +# MAGIC # MAGIC The resulting data frames can be saved, used as a source for other operations, converted to view for # MAGIC consumption from Scala and other languages / environments. -# MAGIC +# MAGIC # MAGIC As the resulting dataframe is a full defined PySpark dataframe, you can supplement resulting data frame with # MAGIC regular spark code to address scenarios not covered by the library. # COMMAND ---------- # MAGIC %md ### Using the Data Generator ### -# MAGIC +# MAGIC # MAGIC lets look at several basic scenarios: # MAGIC * generating a test data set from manually specified columns # MAGIC * generating a test data set from a schema definition @@ -64,14 +64,15 @@ from pyspark.sql.types import IntegerType, StringType, FloatType # will have implied column `id` for ordinal of row -testdata_defn = (dg.DataGenerator(spark, name="basic_dataset", rows=100000, partitions=20) - .withColumn("code1", IntegerType(), minValue=1, maxValue=20, step=1) - .withColumn("code2", IntegerType(), maxValue=1000, step=5) - .withColumn("code3", IntegerType(), minValue=100, maxValue=200, step=1, random=True) - .withColumn("xcode", StringType(), values=["a", "test", "value"], random=True) - .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.01, random=True) - .withColumn("non_scaled_rating", FloatType(), minValue=1.0, maxValue=5.0, continuous=True, - random=True)) +testdata_defn = ( + dg.DataGenerator(spark, name="basic_dataset", rows=100000, partitions=20) + .withColumn("code1", IntegerType(), minValue=1, maxValue=20, step=1) + .withColumn("code2", IntegerType(), maxValue=1000, step=5) + .withColumn("code3", IntegerType(), minValue=100, maxValue=200, step=1, random=True) + .withColumn("xcode", StringType(), values=["a", "test", "value"], random=True) + .withColumn("rating", FloatType(), minValue=1.0, maxValue=5.0, step=0.01, random=True) + .withColumn("non_scaled_rating", FloatType(), minValue=1.0, maxValue=5.0, continuous=True, random=True) +) df = testdata_defn.build() @@ -92,42 +93,38 @@ start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) -schema = StructType([ - StructField("site_id", IntegerType(), True), - StructField("site_cd", StringType(), True), - StructField("c", StringType(), True), - 
StructField("c1", StringType(), True), - StructField("sector_technology_desc", StringType(), True), - -]) +schema = StructType( + [ + StructField("site_id", IntegerType(), True), + StructField("site_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("sector_technology_desc", StringType(), True), + ] +) # build spark session # will have implied column `id` for ordinal of row # number of partitions will control how many Spark tasks the data generation is distributed over -x3 = (dg.DataGenerator(spark, name="my_test_view", rows=1000000, partitions=8) - .withSchema(schema) - # withColumnSpec adds specification for existing column - # here, we speciy data is distributed normally - .withColumnSpec("site_id", minValue=1, maxValue=20, step=1, distribution="normal", random=True) - - # base column specifies dependent column - here the value of site_cd is dependent on the value of site_id - .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') - - # withColumn adds specification for new column - even if the basic data set was initialized from a schema - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) - - .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") - - # generate timestamps in over the specified time range - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) - - # by default all values are populated, but use of percentNulls option introduces nulls randomly - .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], percentNulls=0.05, - random=True) - .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) - ) +x3 = ( + dg.DataGenerator(spark, name="my_test_view", rows=1000000, partitions=8) + .withSchema(schema) + # withColumnSpec adds specification for existing column + # here, we speciy data is distributed normally + .withColumnSpec("site_id", minValue=1, maxValue=20, step=1, distribution="normal", random=True) + # base column specifies dependent column - here the value of site_cd is dependent on the value of site_id + .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') + # withColumn adds specification for new column - even if the basic data set was initialized from a schema + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") + # generate timestamps in over the specified time range + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) + # by default all values are populated, but use of percentNulls option introduces nulls randomly + .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], percentNulls=0.05, random=True) + .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) +) # when we specify ``withTempView`` option, the data is available as view in Scala and SQL code dfOutput = x3.build(withTempView=True) @@ -142,20 +139,20 @@ # COMMAND ---------- # MAGIC %sql -# MAGIC +# MAGIC # MAGIC -- we'll generate row counts by site_id -# MAGIC SELECT site_id, count(*) as row_count_by_site from my_test_view +# MAGIC SELECT site_id, count(*) as row_count_by_site from my_test_view # MAGIC group by site_id # MAGIC order by site_id asc # COMMAND ---------- -# MAGIC %scala -# MAGIC +# MAGIC %scala +# MAGIC # MAGIC val df = spark.sql(""" 
-# MAGIC SELECT site_id, sector_technology_desc, count(*) as row_count_by_site from my_test_view +# MAGIC SELECT site_id, sector_technology_desc, count(*) as row_count_by_site from my_test_view # MAGIC group by site_id, sector_technology_desc # MAGIC order by site_id,sector_technology_desc asc # MAGIC """) -# MAGIC +# MAGIC # MAGIC display(df) diff --git a/tutorial/2-Basics.py b/tutorial/2-Basics.py index 1d75da72..60a26eea 100644 --- a/tutorial/2-Basics.py +++ b/tutorial/2-Basics.py @@ -5,13 +5,13 @@ # COMMAND ---------- # MAGIC %md ##Generating a simple data set -# MAGIC +# MAGIC # MAGIC Lets look at generating a simple data set as follows: -# MAGIC +# MAGIC # MAGIC - use the `id` field as the key # MAGIC - generate a predictable value for a theoretical field `code1` that will be the same on every run. # MAGIC The field will be generated based on modulo arithmetic on the `id` field to produce a code -# MAGIC - generate random fields `code2`, `code3` and `code4` +# MAGIC - generate random fields `code2`, `code3` and `code4` # MAGIC - generate a `site code` as a string version of the `site id` # MAGIC - generate a sector status description # MAGIC - generate a communications technology field with a discrete set of values @@ -21,20 +21,19 @@ import dbldatagen as dg # will have implied column `id` for ordinal of row -testdata_generator = (dg.DataGenerator(spark, name="test_dataset1", rows=100000, partitions=20, randomSeedMethod="hash_fieldname") - .withIdOutput() # id column will be emitted in the output - .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, random=True) - .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) - .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) - # base column specifies dependent column - - .withColumn("site_cd", "string", prefix='site', baseColumn='code1') - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', - random=True) - .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) - .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) - ) +testdata_generator = ( + dg.DataGenerator(spark, name="test_dataset1", rows=100000, partitions=20, randomSeedMethod="hash_fieldname") + .withIdOutput() # id column will be emitted in the output + .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) + .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, random=True) + .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) + .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) + # base column specifies dependent column + .withColumn("site_cd", "string", prefix='site', baseColumn='code1') + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) + .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) +) df = testdata_generator.build() # build our dataset @@ -47,7 +46,7 @@ # COMMAND ---------- # MAGIC %md ### Controlling the starting ID -# MAGIC +# MAGIC # MAGIC Often when we are generating test data, we want multiple data sets and to control how keys are generated for datasets after the first. 
We can control the generation of the `id` field by specifing the starting `id` - for example to simulate arrival of new data rows # COMMAND ---------- @@ -61,9 +60,9 @@ # COMMAND ---------- # MAGIC %md ### Using weights -# MAGIC -# MAGIC In many cases when we have a series of values for a column, they are not distributed uniformly. By specifying weights, the frequency of generation of specific values can be weighted. -# MAGIC +# MAGIC +# MAGIC In many cases when we have a series of values for a column, they are not distributed uniformly. By specifying weights, the frequency of generation of specific values can be weighted. +# MAGIC # MAGIC For example: # COMMAND ---------- @@ -71,22 +70,19 @@ import dbldatagen as dg # will have implied column `id` for ordinal of row -testdata_generator2 = (dg.DataGenerator(spark, name="test_dataset2", rows=100000, partitions=20, - randomSeedMethod="hash_fieldname") - .withIdOutput() # id column will be emitted in the output - .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) - .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, random=True) - .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) - .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) - # base column specifies dependent column - - .withColumn("site_cd", "string", prefix='site', baseColumn='code1') - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', - random=True) - .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], weights=[5, 1, 1, 1], - random=True) - .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) - ) +testdata_generator2 = ( + dg.DataGenerator(spark, name="test_dataset2", rows=100000, partitions=20, randomSeedMethod="hash_fieldname") + .withIdOutput() # id column will be emitted in the output + .withColumn("code1", "integer", minValue=1, maxValue=20, step=1) + .withColumn("code2", "integer", minValue=1, maxValue=20, step=1, random=True) + .withColumn("code3", "integer", minValue=1, maxValue=20, step=1, random=True) + .withColumn("code4", "integer", minValue=1, maxValue=20, step=1, random=True) + # base column specifies dependent column + .withColumn("site_cd", "string", prefix='site', baseColumn='code1') + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + .withColumn("tech", "string", values=["GSM", "UMTS", "LTE", "UNKNOWN"], weights=[5, 1, 1, 1], random=True) + .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) +) df3 = testdata_generator2.build() # build our dataset @@ -95,7 +91,7 @@ # COMMAND ---------- # MAGIC %md ### Generating timestamps -# MAGIC +# MAGIC # MAGIC In many cases when testing ingest pipelines, our test data needs simulated dates or timestamps to simulate the time, date or timestamp to simulate time of ingest or extraction. 
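# Illustrative aside, not part of the patch: a minimal sketch of the timestamp options
# used in the cell below. It assumes a SparkSession named `spark` (as in a Databricks
# notebook) and dbldatagen installed; the interval, row count and spec name here are
# placeholders rather than the tutorial's exact values.
from datetime import datetime, timedelta

import dbldatagen as dg

interval = timedelta(hours=1)  # placeholder granularity between generated timestamps
start = datetime(2017, 10, 1, 0, 0, 0)
end = datetime(2018, 10, 1, 6, 0, 0)

ts_sketch = (
    dg.DataGenerator(spark, name="ts_sketch", rows=1000, partitions=4)
    # begin/end bound the generated values; interval controls their spacing; random=True
    # samples timestamps from the range instead of stepping through it sequentially
    .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True)
)
df_ts = ts_sketch.build()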
# COMMAND ---------- @@ -108,33 +104,35 @@ start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) -schema = StructType([ - StructField("site_id", IntegerType(), True), - StructField("site_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True), - StructField("sector_technology_desc", StringType(), True), - -]) +schema = StructType( + [ + StructField("site_id", IntegerType(), True), + StructField("site_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("sector_technology_desc", StringType(), True), + ] +) # build spark session # will have implied column `id` for ordinal of row -ds = (dg.DataGenerator(spark, name="association_oss_cell_info", rows=100000, partitions=20) - .withSchema(schema) - # withColumnSpec adds specification for existing column - .withColumnSpec("site_id", minValue=1, maxValue=20, step=1) - # base column specifies dependent column - .withIdOutput() - .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) - # withColumn adds specification for new column - .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) - .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) - .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) - ) +ds = ( + dg.DataGenerator(spark, name="association_oss_cell_info", rows=100000, partitions=20) + .withSchema(schema) + # withColumnSpec adds specification for existing column + .withColumnSpec("site_id", minValue=1, maxValue=20, step=1) + # base column specifies dependent column + .withIdOutput() + .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + # withColumn adds specification for new column + .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) + .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) + .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) +) df = ds.build() @@ -154,28 +152,29 @@ partitions_requested = 8 data_rows = 10000000 -spark.sql("""Create table if not exists test_vehicle_data( +spark.sql( + """Create table if not exists test_vehicle_data( name string, serial_number string, license_plate string, email string - ) using Delta""") + ) using Delta""" +) table_schema = spark.table("test_vehicle_data").schema print(table_schema) - -dataspec = (dg.DataGenerator(spark, rows=10000000, partitions=8, - randomSeedMethod="hash_fieldname") - .withSchema(table_schema)) - -dataspec = (dataspec - .withColumnSpec("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') - .withColumnSpec("serial_number", minValue=1000000, maxValue=10000000, - prefix="dr", random=True) - .withColumnSpec("email", template=r'\\w.\\w@\\w.com') - .withColumnSpec("license_plate", template=r'\\n-\\n') - ) + +dataspec = dg.DataGenerator(spark, rows=10000000, partitions=8, randomSeedMethod="hash_fieldname").withSchema( + table_schema +) + +dataspec = ( + dataspec.withColumnSpec("name", percentNulls=0.01, template=r'\\w \\w|\\w a. 
\\w') + .withColumnSpec("serial_number", minValue=1000000, maxValue=10000000, prefix="dr", random=True) + .withColumnSpec("email", template=r'\\w.\\w@\\w.com') + .withColumnSpec("license_plate", template=r'\\n-\\n') +) df1 = dataspec.build() display(df1) @@ -197,53 +196,81 @@ spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions_requested) -country_codes = ['CN', 'US', 'FR', 'CA', 'IN', 'JM', 'IE', 'PK', 'GB', 'IL', 'AU', 'SG', - 'ES', 'GE', 'MX', 'ET', 'SA', 'LB', 'NL'] -country_weights = [1300, 365, 67, 38, 1300, 3, 7, 212, 67, 9, 25, 6, 47, 83, 126, 109, 58, 8, - 17] +country_codes = [ + 'CN', + 'US', + 'FR', + 'CA', + 'IN', + 'JM', + 'IE', + 'PK', + 'GB', + 'IL', + 'AU', + 'SG', + 'ES', + 'GE', + 'MX', + 'ET', + 'SA', + 'LB', + 'NL', +] +country_weights = [1300, 365, 67, 38, 1300, 3, 7, 212, 67, 9, 25, 6, 47, 83, 126, 109, 58, 8, 17] manufacturers = ['Delta corp', 'Xyzzy Inc.', 'Lakehouse Ltd', 'Acme Corp', 'Embanks Devices'] lines = ['delta', 'xyzzy', 'lakehouse', 'gadget', 'droid'] -testDataSpec = (dg.DataGenerator(spark, name="device_data_set", rows=data_rows, - partitions=partitions_requested, randomSeedMethod='hash_fieldname') - .withIdOutput() - # we'll use hash of the base field to generate the ids to - # avoid a simple incrementing sequence - .withColumn("internal_device_id", LongType(), minValue=0x1000000000000, - uniqueValues=device_population, omit=True, baseColumnType="hash") - - # note for format strings, we must use "%lx" not "%x" as the - # underlying value is a long - .withColumn("device_id", StringType(), format="0x%013x", - baseColumn="internal_device_id") - - # the device / user attributes will be the same for the same device id - # so lets use the internal device id as the base column for these attribute - .withColumn("country", StringType(), values=country_codes, - weights=country_weights, - baseColumn="internal_device_id") - .withColumn("manufacturer", StringType(), values=manufacturers, - baseColumn="internal_device_id") - - # use omit = True if you don't want a column to appear in the final output - # but just want to use it as part of generation of another column - .withColumn("line", StringType(), values=lines, baseColumn="manufacturer", - baseColumnType="hash", omit=True) - .withColumn("model_ser", IntegerType(), minValue=1, maxValue=11, - baseColumn="device_id", - baseColumnType="hash", omit=True) - - .withColumn("model_line", StringType(), expr="concat(line, '#', model_ser)", - baseColumn=["line", "model_ser"]) - .withColumn("event_type", StringType(), - values=["activation", "deactivation", "plan change", - "telecoms activity", "internet activity", "device error"], - random=True) - .withColumn("event_ts", "timestamp", begin="2020-01-01 01:00:00", end="2020-12-31 23:59:00", interval="1 minute", random=True) - - ) +testDataSpec = ( + dg.DataGenerator( + spark, + name="device_data_set", + rows=data_rows, + partitions=partitions_requested, + randomSeedMethod='hash_fieldname', + ) + .withIdOutput() + # we'll use hash of the base field to generate the ids to + # avoid a simple incrementing sequence + .withColumn( + "internal_device_id", + LongType(), + minValue=0x1000000000000, + uniqueValues=device_population, + omit=True, + baseColumnType="hash", + ) + # note for format strings, we must use "%lx" not "%x" as the + # underlying value is a long + .withColumn("device_id", StringType(), format="0x%013x", baseColumn="internal_device_id") + # the device / user attributes will be the same for the same device id + # so lets use the internal device id as 
the base column for these attribute + .withColumn("country", StringType(), values=country_codes, weights=country_weights, baseColumn="internal_device_id") + .withColumn("manufacturer", StringType(), values=manufacturers, baseColumn="internal_device_id") + # use omit = True if you don't want a column to appear in the final output + # but just want to use it as part of generation of another column + .withColumn("line", StringType(), values=lines, baseColumn="manufacturer", baseColumnType="hash", omit=True) + .withColumn( + "model_ser", IntegerType(), minValue=1, maxValue=11, baseColumn="device_id", baseColumnType="hash", omit=True + ) + .withColumn("model_line", StringType(), expr="concat(line, '#', model_ser)", baseColumn=["line", "model_ser"]) + .withColumn( + "event_type", + StringType(), + values=["activation", "deactivation", "plan change", "telecoms activity", "internet activity", "device error"], + random=True, + ) + .withColumn( + "event_ts", + "timestamp", + begin="2020-01-01 01:00:00", + end="2020-12-31 23:59:00", + interval="1 minute", + random=True, + ) +) dfTestData = testDataSpec.build() @@ -252,7 +279,7 @@ # COMMAND ---------- # MAGIC %md ##Generating Streaming Data -# MAGIC +# MAGIC # MAGIC We can use the specs from the previous exampled to generate streaming data # COMMAND ---------- @@ -273,54 +300,75 @@ data_rows = 20 * 100000 partitions_requested = 8 -country_codes = ['CN', 'US', 'FR', 'CA', 'IN', 'JM', 'IE', 'PK', 'GB', 'IL', 'AU', 'SG', - 'ES', 'GE', 'MX', 'ET', 'SA', 'LB', 'NL'] -country_weights = [1300, 365, 67, 38, 1300, 3, 7, 212, 67, 9, 25, 6, 47, 83, 126, 109, 58, 8, - 17] +country_codes = [ + 'CN', + 'US', + 'FR', + 'CA', + 'IN', + 'JM', + 'IE', + 'PK', + 'GB', + 'IL', + 'AU', + 'SG', + 'ES', + 'GE', + 'MX', + 'ET', + 'SA', + 'LB', + 'NL', +] +country_weights = [1300, 365, 67, 38, 1300, 3, 7, 212, 67, 9, 25, 6, 47, 83, 126, 109, 58, 8, 17] manufacturers = ['Delta corp', 'Xyzzy Inc.', 'Lakehouse Ltd', 'Acme Corp', 'Embanks Devices'] lines = ['delta', 'xyzzy', 'lakehouse', 'gadget', 'droid'] -testDataSpec = (dg.DataGenerator(spark, name="device_data_set", rows=data_rows, - partitions=partitions_requested, randomSeedMethod='hash_fieldname', - verbose=True) - .withIdOutput() - # we'll use hash of the base field to generate the ids to - # avoid a simple incrementing sequence - .withColumn("internal_device_id", LongType(), minValue=0x1000000000000, - uniqueValues=device_population, omit=True, baseColumnType="hash") - - # note for format strings, we must use "%lx" not "%x" as the - # underlying value is a long - .withColumn("device_id", StringType(), format="0x%013x", - baseColumn="internal_device_id") - - # the device / user attributes will be the same for the same device id - # so lets use the internal device id as the base column for these attribute - .withColumn("country", StringType(), values=country_codes, - weights=country_weights, - baseColumn="internal_device_id") - .withColumn("manufacturer", StringType(), values=manufacturers, - baseColumn="internal_device_id") - - # use omit = True if you don't want a column to appear in the final output - # but just want to use it as part of generation of another column - .withColumn("line", StringType(), values=lines, baseColumn="manufacturer", - baseColumnType="hash", omit=True) - .withColumn("model_ser", IntegerType(), minValue=1, maxValue=11, - baseColumn="device_id", - baseColumnType="hash", omit=True) - - .withColumn("model_line", StringType(), expr="concat(line, '#', model_ser)", - baseColumn=["line", "model_ser"]) - 
.withColumn("event_type", StringType(), - values=["activation", "deactivation", "plan change", - "telecoms activity", "internet activity", "device error"], - random=True) - .withColumn("event_ts", "timestamp", expr="now()") - - ) +testDataSpec = ( + dg.DataGenerator( + spark, + name="device_data_set", + rows=data_rows, + partitions=partitions_requested, + randomSeedMethod='hash_fieldname', + verbose=True, + ) + .withIdOutput() + # we'll use hash of the base field to generate the ids to + # avoid a simple incrementing sequence + .withColumn( + "internal_device_id", + LongType(), + minValue=0x1000000000000, + uniqueValues=device_population, + omit=True, + baseColumnType="hash", + ) + # note for format strings, we must use "%lx" not "%x" as the + # underlying value is a long + .withColumn("device_id", StringType(), format="0x%013x", baseColumn="internal_device_id") + # the device / user attributes will be the same for the same device id + # so lets use the internal device id as the base column for these attribute + .withColumn("country", StringType(), values=country_codes, weights=country_weights, baseColumn="internal_device_id") + .withColumn("manufacturer", StringType(), values=manufacturers, baseColumn="internal_device_id") + # use omit = True if you don't want a column to appear in the final output + # but just want to use it as part of generation of another column + .withColumn("line", StringType(), values=lines, baseColumn="manufacturer", baseColumnType="hash", omit=True) + .withColumn( + "model_ser", IntegerType(), minValue=1, maxValue=11, baseColumn="device_id", baseColumnType="hash", omit=True + ) + .withColumn("model_line", StringType(), expr="concat(line, '#', model_ser)", baseColumn=["line", "model_ser"]) + .withColumn( + "event_type", + StringType(), + values=["activation", "deactivation", "plan change", "telecoms activity", "internet activity", "device error"], + random=True, + ) + .withColumn("event_ts", "timestamp", expr="now()") +) dfTestDataStreaming = testDataSpec.build(withStreaming=True, options={'rowsPerSecond': 500}) @@ -348,31 +396,33 @@ start = datetime(2017, 10, 1, 0, 0, 0) end = datetime(2018, 10, 1, 6, 0, 0) -schema = StructType([ - StructField("site_id", IntegerType(), True), - StructField("site_cd", StringType(), True), - StructField("c", StringType(), True), - StructField("c1", StringType(), True), - StructField("sector_technology_desc", StringType(), True), - -]) +schema = StructType( + [ + StructField("site_id", IntegerType(), True), + StructField("site_cd", StringType(), True), + StructField("c", StringType(), True), + StructField("c1", StringType(), True), + StructField("sector_technology_desc", StringType(), True), + ] +) # will have implied column `id` for ordinal of row -ds = (dg.DataGenerator(spark, name="association_oss_cell_info", rows=100000, partitions=20) - .withSchema(schema) - # withColumnSpec adds specification for existing column - .withColumnSpec("site_id", minValue=1, maxValue=20, step=1) - # base column specifies dependent column - .withIdOutput() - .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') - .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) - # withColumn adds specification for new column - .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") - .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) - .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) - 
.withColumn("test_cell_flg", "integer", values=[0, 1], random=True) - ) +ds = ( + dg.DataGenerator(spark, name="association_oss_cell_info", rows=100000, partitions=20) + .withSchema(schema) + # withColumnSpec adds specification for existing column + .withColumnSpec("site_id", minValue=1, maxValue=20, step=1) + # base column specifies dependent column + .withIdOutput() + .withColumnSpec("site_cd", prefix='site', baseColumn='site_id') + .withColumn("sector_status_desc", "string", minValue=1, maxValue=200, step=1, prefix='status', random=True) + # withColumn adds specification for new column + .withColumn("rand", "float", expr="floor(rand() * 350) * (86400 + 3600)") + .withColumn("last_sync_dt", "timestamp", begin=start, end=end, interval=interval, random=True) + .withColumnSpec("sector_technology_desc", values=["GSM", "UMTS", "LTE", "UNKNOWN"], random=True) + .withColumn("test_cell_flg", "integer", values=[0, 1], random=True) +) df = ds.build(withStreaming=True, options={'rowsPerSecond': 500}) diff --git a/tutorial/3-ChangeDataCapture-example.py b/tutorial/3-ChangeDataCapture-example.py index 3d6daa4c..696451ce 100644 --- a/tutorial/3-ChangeDataCapture-example.py +++ b/tutorial/3-ChangeDataCapture-example.py @@ -9,9 +9,9 @@ # COMMAND ---------- # MAGIC %md #### Overview -# MAGIC -# MAGIC We'll generate a customer table, and write out the data. -# MAGIC +# MAGIC +# MAGIC We'll generate a customer table, and write out the data. +# MAGIC # MAGIC Then we generate changes for the table and show merging them in. # COMMAND ---------- @@ -28,8 +28,8 @@ # COMMAND ---------- -# MAGIC %md Lets generate 10 million customers -# MAGIC +# MAGIC %md Lets generate 10 million customers +# MAGIC # MAGIC We'll add a timestamp for when the row was generated and a memo field to mark what operation added it # COMMAND ---------- @@ -48,29 +48,44 @@ uniqueCustomers = 10 * 1000000 -dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("customer_id", "long", uniqueValues=uniqueCustomers) - .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') - .withColumn("alias", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') - .withColumn("payment_instrument_type", values=['paypal', 'Visa', 'Mastercard', - 'American Express', 'discover', 'branded visa', - 'branded mastercard'], - random=True, distribution="normal") - .withColumn("int_payment_instrument", "int", minValue=0000, maxValue=9999, baseColumn="customer_id", - baseColumnType="hash", omit=True) - .withColumn("payment_instrument", expr="format_number(int_payment_instrument, '**** ****** *####')", - baseColumn="int_payment_instrument") - .withColumn("email", template=r'\\w.\\w@\\w.com|\\w-\\w@\\w') - .withColumn("email2", template=r'\\w.\\w@\\w.com') - .withColumn("ip_address", template=r'\\n.\\n.\\n.\\n') - .withColumn("md5_payment_instrument", - expr="md5(concat(payment_instrument_type, ':', payment_instrument))", - base_column=['payment_instrument_type', 'payment_instrument']) - .withColumn("customer_notes", text=dg.ILText(words=(1,8))) - .withColumn("created_ts", "timestamp", expr="now()") - .withColumn("modified_ts", "timestamp", expr="now()") - .withColumn("memo", expr="'original data'") - ) +dataspec = ( + dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) + .withColumn("customer_id", "long", uniqueValues=uniqueCustomers) + .withColumn("name", percentNulls=0.01, template=r'\\w \\w|\\w a. \\w') + .withColumn("alias", percentNulls=0.01, template=r'\\w \\w|\\w a. 
\\w') + .withColumn( + "payment_instrument_type", + values=['paypal', 'Visa', 'Mastercard', 'American Express', 'discover', 'branded visa', 'branded mastercard'], + random=True, + distribution="normal", + ) + .withColumn( + "int_payment_instrument", + "int", + minValue=0000, + maxValue=9999, + baseColumn="customer_id", + baseColumnType="hash", + omit=True, + ) + .withColumn( + "payment_instrument", + expr="format_number(int_payment_instrument, '**** ****** *####')", + baseColumn="int_payment_instrument", + ) + .withColumn("email", template=r'\\w.\\w@\\w.com|\\w-\\w@\\w') + .withColumn("email2", template=r'\\w.\\w@\\w.com') + .withColumn("ip_address", template=r'\\n.\\n.\\n.\\n') + .withColumn( + "md5_payment_instrument", + expr="md5(concat(payment_instrument_type, ':', payment_instrument))", + base_column=['payment_instrument_type', 'payment_instrument'], + ) + .withColumn("customer_notes", text=dg.ILText(words=(1, 8))) + .withColumn("created_ts", "timestamp", expr="now()") + .withColumn("modified_ts", "timestamp", expr="now()") + .withColumn("memo", expr="'original data'") +) df1 = dataspec.build() # write table @@ -84,57 +99,57 @@ # COMMAND ---------- customers1_location = BASE_PATH + "customers1" -tableDefn=dataspec.scriptTable(name="customers1", location=customers1_location) +tableDefn = dataspec.scriptTable(name="customers1", location=customers1_location) spark.sql(tableDefn) # COMMAND ---------- -# MAGIC %sql +# MAGIC %sql # MAGIC -- lets check our table -# MAGIC +# MAGIC # MAGIC select * from customers1 # COMMAND ---------- # MAGIC %md ### Changes -# MAGIC +# MAGIC # MAGIC Lets generate some changes # COMMAND ---------- import pyspark.sql.functions as F -start_of_new_ids = df1.select(F.max('customer_id')+1).collect()[0][0] +start_of_new_ids = df1.select(F.max('customer_id') + 1).collect()[0][0] print(start_of_new_ids) # todo - as sequence for random columns will restart from previous seeds , you will get repeated values on next generation operation # want to use seed sequences so that new random data is not same as old data from previous runs -df1_inserts = (dataspec.clone() - .option("startingId", start_of_new_ids) - .withRowCount(10 * 1000) - .build() - .withColumn("memo", F.lit("insert")) - .withColumn("customer_id", F.expr(f"customer_id + {start_of_new_ids}")) - ) +df1_inserts = ( + dataspec.clone() + .option("startingId", start_of_new_ids) + .withRowCount(10 * 1000) + .build() + .withColumn("memo", F.lit("insert")) + .withColumn("customer_id", F.expr(f"customer_id + {start_of_new_ids}")) +) # read the written data - if we simply recompute, timestamps of original will be lost df_original = spark.read.format("delta").load(customers1_location) -df1_updates = (df_original.sample(False, 0.1) - .limit(50 * 1000) - .withColumn("alias", F.lit('modified alias')) - .withColumn("modified_ts", F.expr('now()')) - .withColumn("memo", F.lit("update"))) +df1_updates = ( + df_original.sample(False, 0.1) + .limit(50 * 1000) + .withColumn("alias", F.lit('modified alias')) + .withColumn("modified_ts", F.expr('now()')) + .withColumn("memo", F.lit("update")) +) df_changes = df1_inserts.union(df1_updates) # randomize ordering -df_changes = (df_changes.withColumn("order_rand", F.expr("rand()")) - .orderBy("order_rand") - .drop("order_rand") - ) +df_changes = df_changes.withColumn("order_rand", F.expr("rand()")).orderBy("order_rand").drop("order_rand") display(df_changes) @@ -142,18 +157,19 @@ # COMMAND ---------- # MAGIC %md ###Now lets merge in the changes -# MAGIC +# MAGIC # MAGIC We can script the 
merge statement in the data generator # COMMAND ---------- df_changes.dropDuplicates(["customer_id"]).createOrReplaceTempView("customers1_changes") -sqlStmt = dataspec.scriptMerge(tgtName="customers1", srcName="customers1_changes", - joinExpr="src.customer_id=tgt.customer_id", - updateColumns=["alias", "memo","modified_ts"], - updateColumnExprs=[ ("memo", "'updated on merge'"), - ("modified_ts", "now()") - ]) +sqlStmt = dataspec.scriptMerge( + tgtName="customers1", + srcName="customers1_changes", + joinExpr="src.customer_id=tgt.customer_id", + updateColumns=["alias", "memo", "modified_ts"], + updateColumnExprs=[("memo", "'updated on merge'"), ("modified_ts", "now()")], +) print(sqlStmt) @@ -165,16 +181,16 @@ # COMMAND ---------- -# MAGIC %sql +# MAGIC %sql # MAGIC -- lets check our table for updates -# MAGIC +# MAGIC # MAGIC select * from customers1 where created_ts != modified_ts # COMMAND ---------- -# MAGIC %sql +# MAGIC %sql # MAGIC -- lets check our table for inserts -# MAGIC +# MAGIC # MAGIC select * from customers1 where memo = "insert" # COMMAND ---------- diff --git a/tutorial/4-Generating-multi-table-data.py b/tutorial/4-Generating-multi-table-data.py index 9fb1c9f1..9b6cb780 100644 --- a/tutorial/4-Generating-multi-table-data.py +++ b/tutorial/4-Generating-multi-table-data.py @@ -5,30 +5,30 @@ # COMMAND ---------- # MAGIC %md ##Multi table data generation -# MAGIC +# MAGIC # MAGIC To illustrate multi-table data generation and use, we'll use a simplified version of telecoms billing processes. -# MAGIC +# MAGIC # MAGIC Let's assume we have data as follows: -# MAGIC +# MAGIC # MAGIC - A set of customers -# MAGIC - A set of customer device activity events +# MAGIC - A set of customer device activity events # MAGIC - text message # MAGIC - local call # MAGIC - international call # MAGIC - long distance call # MAGIC - internet activity -# MAGIC -# MAGIC - a set of pricing plans indicating +# MAGIC +# MAGIC - a set of pricing plans indicating # MAGIC - cost per MB of internet activity # MAGIC - cost per minute of call for each of the call categories # MAGIC - cost per message -# MAGIC +# MAGIC # MAGIC Internet activitity will be priced per MB transferred -# MAGIC +# MAGIC # MAGIC Phone calls will be priced per minute or partial minute. -# MAGIC +# MAGIC # MAGIC Messages will be priced per actual counts -# MAGIC +# MAGIC # MAGIC For simplicitity, we'll ignore the free data, messages and calls threshold in most plans and the complexity # MAGIC of matching devices to customers and telecoms operators - our goal here is to show generation of join # MAGIC ready data, rather than full modelling of phone usage invoicing. @@ -41,20 +41,22 @@ import re -MARGIN_PATTERN= re.compile(r"\s*\|") # margin detection pattern for stripMargin -def stripMargin(s): - """ strip margin removes leading space in multi line string before '|' """ - return "\n".join(re.split(MARGIN_PATTERN, s)) +MARGIN_PATTERN = re.compile(r"\s*\|") # margin detection pattern for stripMargin + + +def stripMargin(s): + """strip margin removes leading space in multi line string before '|'""" + return "\n".join(re.split(MARGIN_PATTERN, s)) # COMMAND ---------- # MAGIC %md ### Let's model our plans -# MAGIC +# MAGIC # MAGIC Note, we use two columns, `ld_multipler` and `intl_multiplier` just as intermediate results used in later calculations and omit them from the output. -# MAGIC +# MAGIC # MAGIC We use `decimal` types for prices to avoid rounding issues. -# MAGIC +# MAGIC # MAGIC Here we use a simple sequence for our plan ids. 
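# Illustrative aside, not part of the patch: the core of the pattern described above,
# reduced to a single derived price column. An intermediate multiplier is generated with
# omit=True so it can feed a dependent expression without appearing in the output, and
# decimal types are used for prices. Assumes `spark` and dbldatagen imported as dg; the
# names, ranges and row count are placeholders.
import dbldatagen as dg

plan_sketch = (
    dg.DataGenerator(spark, rows=20, partitions=1)
    .withColumn("plan_id", "int", minValue=100, uniqueValues=20)  # simple sequence of plan ids
    .withColumn("cost_per_minute", "decimal(5,3)", minValue=0.001, maxValue=0.01, step=0.001, random=True)
    # intermediate multiplier - feeds the next column but is omitted from the output
    .withColumn("ld_multiplier", "decimal(5,3)", minValue=1.5, maxValue=3, step=0.05, random=True, omit=True)
    .withColumn(
        "ld_cost_per_minute",
        "decimal(5,3)",
        expr="cost_per_minute * ld_multiplier",
        baseColumns=["cost_per_minute", "ld_multiplier"],
    )
)
df_plan_sketch = plan_sketch.build()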
# COMMAND ---------- @@ -69,36 +71,56 @@ def stripMargin(s): shuffle_partitions_requested = 8 partitions_requested = 1 -data_rows = UNIQUE_PLANS # we'll generate one row for each plan +data_rows = UNIQUE_PLANS # we'll generate one row for each plan spark.conf.set("spark.sql.shuffle.partitions", shuffle_partitions_requested) spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 20000) -plan_dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("plan_id","int", minValue=PLAN_MIN_VALUE, uniqueValues=UNIQUE_PLANS) - .withColumn("plan_name", prefix="plan", baseColumn="plan_id") # use plan_id as root value - - # note default step is 1 so you must specify a step for small number ranges, - .withColumn("cost_per_mb", "decimal(5,3)", minValue=0.005, maxValue=0.050, step=0.005, random=True) - .withColumn("cost_per_message", "decimal(5,3)", minValue=0.001, maxValue=0.02, step=0.001, random=True) - .withColumn("cost_per_minute", "decimal(5,3)", minValue=0.001, maxValue=0.01, step=0.001, random=True) - - # we're modelling long distance and international prices simplistically - each is a multiplier thats applied to base rate - .withColumn("ld_multiplier", "decimal(5,3)", minValue=1.5, maxValue=3, step=0.05, random=True, - distribution="normal", omit=True) - .withColumn("ld_cost_per_minute", "decimal(5,3)", expr="cost_per_minute * ld_multiplier", - baseColumns=['cost_per_minute', 'ld_multiplier']) - .withColumn("intl_multiplier", "decimal(5,3)", minValue=2, maxValue=4, step=0.05, random=True, - distribution="normal", omit=True) - .withColumn("intl_cost_per_minute", "decimal(5,3)", expr="cost_per_minute * intl_multiplier", - baseColumns=['cost_per_minute', 'intl_multiplier']) - - ) -df_plans = (plan_dataspec.build() - .cache() - ) +plan_dataspec = ( + dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) + .withColumn("plan_id", "int", minValue=PLAN_MIN_VALUE, uniqueValues=UNIQUE_PLANS) + .withColumn("plan_name", prefix="plan", baseColumn="plan_id") # use plan_id as root value + # note default step is 1 so you must specify a step for small number ranges, + .withColumn("cost_per_mb", "decimal(5,3)", minValue=0.005, maxValue=0.050, step=0.005, random=True) + .withColumn("cost_per_message", "decimal(5,3)", minValue=0.001, maxValue=0.02, step=0.001, random=True) + .withColumn("cost_per_minute", "decimal(5,3)", minValue=0.001, maxValue=0.01, step=0.001, random=True) + # we're modelling long distance and international prices simplistically - each is a multiplier thats applied to base rate + .withColumn( + "ld_multiplier", + "decimal(5,3)", + minValue=1.5, + maxValue=3, + step=0.05, + random=True, + distribution="normal", + omit=True, + ) + .withColumn( + "ld_cost_per_minute", + "decimal(5,3)", + expr="cost_per_minute * ld_multiplier", + baseColumns=['cost_per_minute', 'ld_multiplier'], + ) + .withColumn( + "intl_multiplier", + "decimal(5,3)", + minValue=2, + maxValue=4, + step=0.05, + random=True, + distribution="normal", + omit=True, + ) + .withColumn( + "intl_cost_per_minute", + "decimal(5,3)", + expr="cost_per_minute * intl_multiplier", + baseColumns=['cost_per_minute', 'intl_multiplier'], + ) +) +df_plans = plan_dataspec.build().cache() display(df_plans) @@ -106,19 +128,19 @@ def stripMargin(s): # COMMAND ---------- # MAGIC %md ###Lets model our customers -# MAGIC -# MAGIC We'll use device id as the foreign key for device events here. 
-# MAGIC -# MAGIC we want to ensure that our device id is unique for each customer. We could use a simple sequence as with plans but for the purposes of illustration, we'll use a hash of the customer ids instead. -# MAGIC +# MAGIC +# MAGIC We'll use device id as the foreign key for device events here. +# MAGIC +# MAGIC we want to ensure that our device id is unique for each customer. We could use a simple sequence as with plans but for the purposes of illustration, we'll use a hash of the customer ids instead. +# MAGIC # MAGIC There's still a small likelihood of hash collisions so we'll remove any duplicates from the generated data - but in practice, we do not see duplicates in most datasets when using hashing. As all data produced by the framework is repeatable when not using random , or when using random with a seed, this will give us a predictable range of foreign keys. -# MAGIC +# MAGIC # MAGIC Use of hashes and sequences is a very efficient way of generating unique predictable keys while introducing some pseudo-randomness in the values. -# MAGIC -# MAGIC +# MAGIC +# MAGIC # MAGIC Note - for real telephony systems, theres a complex set of rules around device ids (IMEI and related device ids), subscriber numbers and techniques for matching devices to subscribers. Again, our goal here is to illustrate generating predictable join keys not full modelling of a telephony system. -# MAGIC -# MAGIC We use decimal types for ids to avoid exceeding the range of ints and longs when working with larger numbers of customers. +# MAGIC +# MAGIC We use decimal types for ids to avoid exceeding the range of ints and longs when working with larger numbers of customers. # COMMAND ---------- @@ -139,57 +161,62 @@ def stripMargin(s): partitions_requested = 8 data_rows = UNIQUE_CUSTOMERS -customer_dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) - .withColumn("customer_id","decimal(10)", minValue=CUSTOMER_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS) - .withColumn("customer_name", template=r"\\w \\w|\\w a. \\w") - - # use the following for a simple sequence - #.withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS) - - .withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, - baseColumn="customer_id", baseColumnType="hash") - - .withColumn("phone_number","decimal(10)", minValue=SUBSCRIBER_NUM_MIN_VALUE, - baseColumn=["customer_id", "customer_name"], baseColumnType="hash") - - # for email, we'll just use the formatted phone number - .withColumn("email","string", format="subscriber_%s@myoperator.com", baseColumn="phone_number") - .withColumn("plan", "int", minValue=PLAN_MIN_VALUE, uniqueValues=UNIQUE_PLANS, random=True) - ) - -df_customers = (customer_dataspec.build() - .dropDuplicates(["device_id"]) - .dropDuplicates(["phone_number"]) - .orderBy("customer_id") - .cache() - ) +customer_dataspec = ( + dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested) + .withColumn("customer_id", "decimal(10)", minValue=CUSTOMER_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS) + .withColumn("customer_name", template=r"\\w \\w|\\w a. 
\\w") + # use the following for a simple sequence + # .withColumn("device_id","decimal(10)", minValue=DEVICE_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS) + .withColumn("device_id", "decimal(10)", minValue=DEVICE_MIN_VALUE, baseColumn="customer_id", baseColumnType="hash") + .withColumn( + "phone_number", + "decimal(10)", + minValue=SUBSCRIBER_NUM_MIN_VALUE, + baseColumn=["customer_id", "customer_name"], + baseColumnType="hash", + ) + # for email, we'll just use the formatted phone number + .withColumn("email", "string", format="subscriber_%s@myoperator.com", baseColumn="phone_number") + .withColumn("plan", "int", minValue=PLAN_MIN_VALUE, uniqueValues=UNIQUE_PLANS, random=True) +) + +df_customers = ( + customer_dataspec.build() + .dropDuplicates(["device_id"]) + .dropDuplicates(["phone_number"]) + .orderBy("customer_id") + .cache() +) effective_customers = df_customers.count() -print(stripMargin(f"""revised customers : {df_customers.count()}, +print( + stripMargin( + f"""revised customers : {df_customers.count()}, | unique customers: {df_customers.select(F.countDistinct('customer_id')).take(1)[0][0]}, | unique device ids: {df_customers.select(F.countDistinct('device_id')).take(1)[0][0]}, - | unique phone numbers: {df_customers.select(F.countDistinct('phone_number')).take(1)[0][0]}""") - ) + | unique phone numbers: {df_customers.select(F.countDistinct('phone_number')).take(1)[0][0]}""" + ) +) display(df_customers) # COMMAND ---------- # MAGIC %md ###Now lets model our device events -# MAGIC -# MAGIC Generating `master-detail` style data is one of the key challenges in data generation for join ready data. -# MAGIC +# MAGIC +# MAGIC Generating `master-detail` style data is one of the key challenges in data generation for join ready data. +# MAGIC # MAGIC What do we mean by `master-detail`? -# MAGIC -# MAGIC This is where the goal is to model data that consists of large grained entities, that are in turn comprised of smaller items. For example invoices and their respective line items follow this pattern. -# MAGIC -# MAGIC IOT data has similar characteristics. Usually you have a series of devices that generate time series style events from their respective systems and subsystems - each data row being an observation of some measure from some subsystem at a point in time. -# MAGIC +# MAGIC +# MAGIC This is where the goal is to model data that consists of large grained entities, that are in turn comprised of smaller items. For example invoices and their respective line items follow this pattern. +# MAGIC +# MAGIC IOT data has similar characteristics. Usually you have a series of devices that generate time series style events from their respective systems and subsystems - each data row being an observation of some measure from some subsystem at a point in time. +# MAGIC # MAGIC Telephony billing activity has characteristics of both IOT data and master detail data. -# MAGIC +# MAGIC # MAGIC For the telephony events, we want to ensure that on average `n` events occur per device per day and that text and internet browsing is more frequent than phone calls. 
-# MAGIC +# MAGIC # MAGIC A simple approach is simply to multiple the `number of customers` by `number of days in data set` by `average events per day` # COMMAND ---------- @@ -201,7 +228,7 @@ def stripMargin(s): spark.catalog.clearCache() # clear cache so that if we run multiple times to check performance, we're not relying on cache shuffle_partitions_requested = 8 partitions_requested = 8 -NUM_DAYS=31 +NUM_DAYS = 31 MB_100 = 100 * 1000 * 1000 K_1 = 1000 data_rows = AVG_EVENTS_PER_CUSTOMER * UNIQUE_CUSTOMERS * NUM_DAYS @@ -212,48 +239,80 @@ def stripMargin(s): # use random seed method of 'hash_fieldname' for better spread - default in later builds -events_dataspec = (dg.DataGenerator(spark, rows=data_rows, partitions=partitions_requested, randomSeed=42, - randomSeedMethod="hash_fieldname") - # use same logic as per customers dataset to ensure matching keys - but make them random - .withColumn("device_id_base", "decimal(10)", minValue=CUSTOMER_MIN_VALUE, uniqueValues=UNIQUE_CUSTOMERS, - random=True, omit=True) - .withColumn("device_id", "decimal(10)", minValue=DEVICE_MIN_VALUE, - baseColumn="device_id_base", baseColumnType="hash") - - # use specific random seed to get better spread of values - .withColumn("event_type", "string", values=["sms", "internet", "local call", "ld call", "intl call"], - weights=[50, 50, 20, 10, 5], random=True) - - # use Gamma distribution for skew towards short calls - .withColumn("base_minutes","decimal(7,2)", minValue=1.0, maxValue=100.0, step=0.1, - distribution=dg.distributions.Gamma(shape=1.5, scale=2.0), random=True, omit=True) - - # use Gamma distribution for skew towards short transfers - .withColumn("base_bytes_transferred","decimal(12)", minValue=K_1, maxValue=MB_100, - distribution=dg.distributions.Gamma(shape=0.75, scale=2.0), random=True, omit=True) - - .withColumn("minutes", "decimal(7,2)", baseColumn=["event_type", "base_minutes"], - expr=""" +events_dataspec = ( + dg.DataGenerator( + spark, rows=data_rows, partitions=partitions_requested, randomSeed=42, randomSeedMethod="hash_fieldname" + ) + # use same logic as per customers dataset to ensure matching keys - but make them random + .withColumn( + "device_id_base", + "decimal(10)", + minValue=CUSTOMER_MIN_VALUE, + uniqueValues=UNIQUE_CUSTOMERS, + random=True, + omit=True, + ) + .withColumn( + "device_id", "decimal(10)", minValue=DEVICE_MIN_VALUE, baseColumn="device_id_base", baseColumnType="hash" + ) + # use specific random seed to get better spread of values + .withColumn( + "event_type", + "string", + values=["sms", "internet", "local call", "ld call", "intl call"], + weights=[50, 50, 20, 10, 5], + random=True, + ) + # use Gamma distribution for skew towards short calls + .withColumn( + "base_minutes", + "decimal(7,2)", + minValue=1.0, + maxValue=100.0, + step=0.1, + distribution=dg.distributions.Gamma(shape=1.5, scale=2.0), + random=True, + omit=True, + ) + # use Gamma distribution for skew towards short transfers + .withColumn( + "base_bytes_transferred", + "decimal(12)", + minValue=K_1, + maxValue=MB_100, + distribution=dg.distributions.Gamma(shape=0.75, scale=2.0), + random=True, + omit=True, + ) + .withColumn( + "minutes", + "decimal(7,2)", + baseColumn=["event_type", "base_minutes"], + expr=""" case when event_type in ("local call", "ld call", "intl call") then base_minutes else 0 end - """) - .withColumn("bytes_transferred", "decimal(12)", baseColumn=["event_type", "base_bytes_transferred"], - expr=""" + """, + ) + .withColumn( + "bytes_transferred", + "decimal(12)", + 
baseColumn=["event_type", "base_bytes_transferred"], + expr=""" case when event_type = "internet" then base_bytes_transferred else 0 end - """) - - .withColumn("event_ts", "timestamp", data_range=dg.DateRange("2020-07-01 00:00:00", - "2020-07-31 11:59:59", - "seconds=1"), - random=True) - - ) - -df_events = (events_dataspec.build() - ) + """, + ) + .withColumn( + "event_ts", + "timestamp", + data_range=dg.DateRange("2020-07-01 00:00:00", "2020-07-31 11:59:59", "seconds=1"), + random=True, + ) +) + +df_events = events_dataspec.build() display(df_events) @@ -270,7 +329,7 @@ def stripMargin(s): df_customer_pricing = df_customers.join(df_plans, df_plans.plan_id == df_customers.plan) display(df_customer_pricing) - + # COMMAND ---------- @@ -282,16 +341,25 @@ def stripMargin(s): # lets compute the summary minutes messages and bytes transferred -df_enriched_events = (df_events - .withColumn("message_count", F.expr("case when event_type='sms' then 1 else 0 end")) - .withColumn("ld_minutes", F.expr("case when event_type='ld call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end")) - .withColumn("local_minutes", F.expr("case when event_type='local call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end")) - .withColumn("intl_minutes", F.expr("case when event_type='intl call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end")) - ) +df_enriched_events = ( + df_events.withColumn("message_count", F.expr("case when event_type='sms' then 1 else 0 end")) + .withColumn( + "ld_minutes", F.expr("case when event_type='ld call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end") + ) + .withColumn( + "local_minutes", + F.expr("case when event_type='local call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end"), + ) + .withColumn( + "intl_minutes", + F.expr("case when event_type='intl call' then cast(ceil(minutes) as decimal(18,3)) else 0.0 end"), + ) +) df_enriched_events.createOrReplaceTempView("telephony_events") -df_summary = spark.sql("""select device_id, +df_summary = spark.sql( + """select device_id, round(sum(bytes_transferred) / 1000000.0, 3) as total_mb, sum(message_count) as total_messages, sum(ld_minutes) as total_ld_minutes, @@ -301,7 +369,8 @@ def stripMargin(s): from telephony_events group by device_id -""") +""" +) df_summary.createOrReplaceTempView("event_summary") @@ -314,10 +383,12 @@ def stripMargin(s): # COMMAND ---------- -df_customer_summary = (df_customer_pricing.join(df_summary,df_customer_pricing.device_id == df_summary.device_id ) - .createOrReplaceTempView("customer_summary")) +df_customer_summary = df_customer_pricing.join( + df_summary, df_customer_pricing.device_id == df_summary.device_id +).createOrReplaceTempView("customer_summary") -df_invoices = spark.sql("""select *, +df_invoices = spark.sql( + """select *, internet_cost + sms_cost + ld_cost + local_cost + intl_cost as total_invoice from (select customer_id, customer_name, phone_number, email, plan_name, @@ -328,7 +399,8 @@ def stripMargin(s): cast(round(total_messages * cost_per_message, 2) as decimal(18,2)) as sms_cost from customer_summary) -""") +""" +) display(df_invoices) From 9e984eb650c0958356165abaaed480cc3c713ede Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 15:31:38 -0500 Subject: [PATCH 2/8] Improve test coverage --- tests/test_distributions.py | 16 ++++++++++++ tests/test_generation_from_data.py | 23 +++++++++++++++++ tests/test_quick_tests.py | 36 +++++++++++++++++++++++++++ tests/test_ranged_values_and_dates.py | 19 ++++++++++++++ tests/test_utils.py | 6 +++++ 5 
files changed, 100 insertions(+) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 861538c9..2197a15c 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -486,3 +486,19 @@ def test_exponential_generation_func(self): assert s2 == pytest.approx(0.10, abs=0.05) assert m2 == pytest.approx(0.10, abs=0.05) + + def test_exponential_requires_rate_for_scale(self): + """Ensure accessing scale without a rate produces a clear error.""" + exp = dist.Exponential() + with pytest.raises( + ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'" + ): + _ = exp.scale + + def test_exponential_requires_rate_for_generation(self): + """Ensure generating samples without a rate produces a clear error.""" + exp = dist.Exponential() + with pytest.raises( + ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'" + ): + _ = exp.generateNormalizedDistributionSample() diff --git a/tests/test_generation_from_data.py b/tests/test_generation_from_data.py index 19688ac6..717eb083 100644 --- a/tests/test_generation_from_data.py +++ b/tests/test_generation_from_data.py @@ -119,3 +119,26 @@ def test_df_containing_summary(self): summary_df = dg.DataAnalyzer(sparkSession=spark, df=df).summarizeToDF() assert summary_df.count() == 10 + + def test_data_analyzer_requires_dataframe(self): + """Validate that DataAnalyzer cannot be initialized without a DataFrame.""" + with pytest.raises( + ValueError, match="Argument `df` must be supplied when initializing a `DataAnalyzer`" + ): + dg.DataAnalyzer() + + def test_add_measure_to_summary_requires_dataframe(self): + """Validate that _addMeasureToSummary enforces a non-null dfData argument.""" + with pytest.raises( + ValueError, + match="Input DataFrame `dfData` must be supplied when adding measures to a summary", + ): + dg.DataAnalyzer._addMeasureToSummary("measure_name", dfData=None) + + def test_generator_default_attributes_from_type_requires_datatype(self): + """Validate that _generatorDefaultAttributesFromType enforces a DataType instance.""" + with pytest.raises( + ValueError, + match=r"Argument 'sqlType' with type .* must be an instance of `pyspark\.sql\.types\.DataType`", + ): + dg.DataAnalyzer._generatorDefaultAttributesFromType("not-a-sql-type") diff --git a/tests/test_quick_tests.py b/tests/test_quick_tests.py index 1a04d4bb..2e998edf 100644 --- a/tests/test_quick_tests.py +++ b/tests/test_quick_tests.py @@ -464,6 +464,42 @@ def test_empty_range(self): empty_range = NRange() assert empty_range.isEmpty() + def test_nrange_legacy_min_and_minvalue_conflict(self): + """Ensure conflicting legacy 'min' and 'minValue' arguments raise a clear error.""" + with pytest.raises(ValueError, match="Only one of 'minValue' and legacy 'min' may be specified"): + NRange(minValue=0.0, min=1.0) + + def test_nrange_legacy_min_must_be_numeric(self): + """Ensure legacy 'min' argument must be numeric.""" + with pytest.raises(ValueError, match=r"Legacy 'min' argument must be an integer or float\."): + NRange(min="not-a-number") + + def test_nrange_unexpected_kwargs_error_message(self): + """Ensure unexpected keyword arguments produce a helpful error.""" + with pytest.raises(ValueError, match=r"Unexpected keyword arguments for NRange: .*"): + NRange(foo=1) + + def test_nrange_maxvalue_and_until_conflict(self): + """Ensure conflicting 'maxValue' and 'until' arguments raise a clear error.""" + with pytest.raises(ValueError, match="Only one of 'maxValue' or 'until' may be specified."): + NRange(maxValue=10, 
until=20) + + def test_nrange_discrete_range_requires_min_max_step(self): + """Ensure getDiscreteRange validates required attributes.""" + rng = NRange(minValue=0.0, maxValue=10.0) + with pytest.raises( + ValueError, match="Range must have 'minValue', 'maxValue', and 'step' defined\\." + ): + _ = rng.getDiscreteRange() + + def test_nrange_discrete_range_step_must_be_non_zero(self): + """Ensure getDiscreteRange validates non-zero step.""" + rng = NRange(minValue=0.0, maxValue=10.0, step=0) + with pytest.raises( + ValueError, match="Parameter 'step' must be non-zero when computing discrete range\\." + ): + _ = rng.getDiscreteRange() + def test_reversed_ranges(self): testDataSpec = ( dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, partitions=4) diff --git a/tests/test_ranged_values_and_dates.py b/tests/test_ranged_values_and_dates.py index e6453895..fb40d1c6 100644 --- a/tests/test_ranged_values_and_dates.py +++ b/tests/test_ranged_values_and_dates.py @@ -4,6 +4,7 @@ import pyspark.sql.functions as F from pyspark.sql.types import DoubleType, ShortType, LongType, DecimalType, ByteType, DateType from pyspark.sql.types import IntegerType, StringType, FloatType, TimestampType +import pytest import dbldatagen as dg from dbldatagen import DateRange @@ -1033,3 +1034,21 @@ def test_ranged_data_string5(self): s1_expected_values = [f"testing {x:05} >>" for x in [1.5, 1.8, 2.1, 2.4]] s1_values = [r[0] for r in results.select("s1").distinct().collect()] self.assertSetEqual(set(s1_expected_values), set(s1_values)) + + +def test_daterange_parse_interval_requires_value(): + """Validate DateRange.parseInterval requires a non-null interval string.""" + with pytest.raises(ValueError, match="Parameter 'interval_str' must be specified"): + DateRange.parseInterval(None) + + +def test_daterange_compute_date_range_unique_values_positive(): + """Validate DateRange.computeDateRange enforces positive unique_values.""" + with pytest.raises(ValueError, match="Parameter 'unique_values' must be a positive integer"): + DateRange.computeDateRange(begin=None, end=None, interval="days=1", unique_values=0) + + +def test_daterange_compute_timestamp_range_unique_values_positive(): + """Validate DateRange.computeTimestampRange enforces positive unique_values.""" + with pytest.raises(ValueError, match="Parameter 'unique_values' must be a positive integer"): + DateRange.computeTimestampRange(begin=None, end=None, interval="days=1", unique_values=-5) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9c89c0e8..d28b606e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -219,3 +219,9 @@ def test_json_value_from_path(self, path, jsonData, defaultValue, expectedValue) def test_system_time_millis(self): curr_time = system_time_millis() assert curr_time > 0 + + def test_topological_sort_cycle_error_message(self): + """Validate that topologicalSort raises a helpful error message for cyclic dependencies.""" + deps = [("a", {"b"}), ("b", {"a"})] + with pytest.raises(ValueError, match="cyclic or missing dependency detected"): + topologicalSort(deps) From e3ee9e92addea3288963a336046ce73d599b1e66 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 15:36:08 -0500 Subject: [PATCH 3/8] Format tests --- tests/test_distributions.py | 8 ++------ tests/test_generation_from_data.py | 4 +--- tests/test_quick_tests.py | 8 ++------ 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 2197a15c..faa03db2 100644 --- 
a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -490,15 +490,11 @@ def test_exponential_generation_func(self): def test_exponential_requires_rate_for_scale(self): """Ensure accessing scale without a rate produces a clear error.""" exp = dist.Exponential() - with pytest.raises( - ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'" - ): + with pytest.raises(ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'"): _ = exp.scale def test_exponential_requires_rate_for_generation(self): """Ensure generating samples without a rate produces a clear error.""" exp = dist.Exponential() - with pytest.raises( - ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'" - ): + with pytest.raises(ValueError, match="Cannot compute value for 'scale'; Missing value for 'rate'"): _ = exp.generateNormalizedDistributionSample() diff --git a/tests/test_generation_from_data.py b/tests/test_generation_from_data.py index 717eb083..2b71f84a 100644 --- a/tests/test_generation_from_data.py +++ b/tests/test_generation_from_data.py @@ -122,9 +122,7 @@ def test_df_containing_summary(self): def test_data_analyzer_requires_dataframe(self): """Validate that DataAnalyzer cannot be initialized without a DataFrame.""" - with pytest.raises( - ValueError, match="Argument `df` must be supplied when initializing a `DataAnalyzer`" - ): + with pytest.raises(ValueError, match="Argument `df` must be supplied when initializing a `DataAnalyzer`"): dg.DataAnalyzer() def test_add_measure_to_summary_requires_dataframe(self): diff --git a/tests/test_quick_tests.py b/tests/test_quick_tests.py index 2e998edf..9ba0c0fb 100644 --- a/tests/test_quick_tests.py +++ b/tests/test_quick_tests.py @@ -487,17 +487,13 @@ def test_nrange_maxvalue_and_until_conflict(self): def test_nrange_discrete_range_requires_min_max_step(self): """Ensure getDiscreteRange validates required attributes.""" rng = NRange(minValue=0.0, maxValue=10.0) - with pytest.raises( - ValueError, match="Range must have 'minValue', 'maxValue', and 'step' defined\\." - ): + with pytest.raises(ValueError, match="Range must have 'minValue', 'maxValue', and 'step' defined\\."): _ = rng.getDiscreteRange() def test_nrange_discrete_range_step_must_be_non_zero(self): """Ensure getDiscreteRange validates non-zero step.""" rng = NRange(minValue=0.0, maxValue=10.0, step=0) - with pytest.raises( - ValueError, match="Parameter 'step' must be non-zero when computing discrete range\\." - ): + with pytest.raises(ValueError, match="Parameter 'step' must be non-zero when computing discrete range\\."): _ = rng.getDiscreteRange() def test_reversed_ranges(self): From a628119584f41c751171c0afa09474447fb94121 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 15:51:27 -0500 Subject: [PATCH 4/8] Refactor --- dbldatagen/multi_table_builder.py | 276 ------------------------------ dbldatagen/relation.py | 33 ---- tests/test_multi_table.py | 146 ---------------- 3 files changed, 455 deletions(-) delete mode 100644 dbldatagen/multi_table_builder.py delete mode 100644 dbldatagen/relation.py delete mode 100644 tests/test_multi_table.py diff --git a/dbldatagen/multi_table_builder.py b/dbldatagen/multi_table_builder.py deleted file mode 100644 index fb7f693e..00000000 --- a/dbldatagen/multi_table_builder.py +++ /dev/null @@ -1,276 +0,0 @@ -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -""" -This module defines the ``MultiTableBuilder`` class used for managing relational datasets. -""" - -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import dataclass - -from pyspark.sql import DataFrame - -from dbldatagen.data_generator import DataGenerator -from dbldatagen.datagen_types import ColumnLike -from dbldatagen.relation import ForeignKeyRelation -from dbldatagen.utils import DataGenError, ensure_column - - -class MultiTableBuilder: - """ - Basic builder for managing multiple related datasets backed by ``DataGenerator`` instances. - - This initial implementation focuses on tracking datasets, static DataFrames, and foreign key relations. - Related tables must share a single ``DataGenerator`` so the rows can be generated in a single pass. - """ - - def __init__(self) -> None: - self._datasets: dict[str, _DatasetDefinition] = {} - self._data_generators: list[DataGenerator] = [] - self._static_dataframes: list[DataFrame] = [] - self._foreign_key_relations: list[ForeignKeyRelation] = [] - self._generator_cache: dict[int, DataFrame] = {} - - @property - def data_generators(self) -> list[DataGenerator]: - """ - List of unique ``DataGenerator`` instances tracked by the builder. - """ - return list(dict.fromkeys(self._data_generators)) - - @property - def static_dataframes(self) -> list[DataFrame]: - """ - List of static ``DataFrame`` objects tracked by the builder. - """ - return list(self._static_dataframes) - - @property - def foreign_key_relations(self) -> list[ForeignKeyRelation]: - """ - List of registered :class:`ForeignKeyRelation` objects. - """ - return list(self._foreign_key_relations) - - def add_data_generator( - self, - name: str, - generator: DataGenerator, - columns: Sequence[ColumnLike] | None = None, - ) -> None: - """ - Register a dataset backed by a ``DataGenerator``. - - :param name: Dataset name - :param generator: Generator instance capable of producing all required columns - :param columns: Default column projection for the dataset - """ - if name in self._datasets: - raise DataGenError(f"Dataset '{name}' is already defined.") - - self._datasets[name] = _DatasetDefinition( - name=name, - generator=generator, - columns=tuple(columns) if columns is not None else None, - ) - self._data_generators.append(generator) - - def add_static_dataframe( - self, - name: str, - dataframe: DataFrame, - columns: Sequence[ColumnLike] | None = None, - ) -> None: - """ - Register a dataset backed by a pre-built ``DataFrame``. - - :param name: Dataset name - :param dataframe: Static ``DataFrame`` instance - :param columns: Default column projection for the dataset - """ - if name in self._datasets: - raise DataGenError(f"Dataset '{name}' is already defined.") - - self._datasets[name] = _DatasetDefinition( - name=name, - dataframe=dataframe, - columns=tuple(columns) if columns is not None else None, - ) - self._static_dataframes.append(dataframe) - - def add_foreign_key_relation( - self, - relation: ForeignKeyRelation | None = None, - *, - from_table: str | None = None, - from_column: ColumnLike | None = None, - to_table: str | None = None, - to_column: ColumnLike | None = None, - ) -> ForeignKeyRelation: - """ - Register a foreign key relation between two datasets. - - The relation can be provided via a fully constructed ``ForeignKeyRelation`` or via keyword arguments. 
- - :param relation: Optional ``ForeignKeyRelation`` instance - :param from_table: Referencing dataset name (required if ``relation`` not supplied) - :param from_column: Referencing column (required if ``relation`` not supplied) - :param to_table: Referenced dataset name (required if ``relation`` not supplied) - :param to_column: Referenced column (required if ``relation`` not supplied) - :return: Registered relation - """ - if relation is None: - if not all([from_table, from_column, to_table, to_column]): - raise DataGenError("Foreign key relation requires table and column details.") - relation = ForeignKeyRelation( - from_table=from_table, # type: ignore[arg-type] - from_column=from_column, # type: ignore[arg-type] - to_table=to_table, # type: ignore[arg-type] - to_column=to_column, # type: ignore[arg-type] - ) - - self._validate_dataset_exists(relation.from_table) - self._validate_dataset_exists(relation.to_table) - - self._foreign_key_relations.append(relation) - return relation - - def build( - self, - dataset_names: Sequence[str] | None = None, - column_overrides: Mapping[str, Sequence[ColumnLike]] | None = None, - ) -> dict[str, DataFrame]: - """ - Materialize one or more datasets managed by the builder. - - :param dataset_names: Optional list of dataset names to build (defaults to all datasets) - :param column_overrides: Optional mapping of dataset name to column overrides - :return: Dictionary keyed by dataset name containing Spark ``DataFrame`` objects - """ - targets = dataset_names or list(self._datasets) - results: dict[str, DataFrame] = {} - - for name in targets: - overrides = column_overrides[name] if column_overrides and name in column_overrides else None - results[name] = self.get_dataset(name, columns=overrides) - - return results - - def get_dataset(self, name: str, columns: Sequence[ColumnLike] | None = None) -> DataFrame: - """ - Retrieve a single dataset as a ``DataFrame`` applying optional column overrides. - - :param name: Dataset name - :param columns: Optional select expressions to override defaults - :return: Spark ``DataFrame`` with projected columns - """ - dataset = self._datasets.get(name) - if dataset is None: - raise DataGenError(f"Dataset '{name}' is not defined.") - - if dataset.dataframe is not None: - return dataset.select_columns(dataset.dataframe, columns) - - assert dataset.generator is not None - self._ensure_shared_generator(name) - - base_df = self._get_or_build_generator_output(dataset.generator) - return dataset.select_columns(base_df, columns) - - def clear_cache(self) -> None: - """ - Clear cached ``DataFrame`` results for generator-backed datasets. - """ - self._generator_cache.clear() - - def _validate_dataset_exists(self, name: str) -> None: - if name not in self._datasets: - raise DataGenError(f"Dataset '{name}' is not registered with the builder.") - - def _get_or_build_generator_output(self, generator: DataGenerator) -> DataFrame: - generator_id = id(generator) - if generator_id not in self._generator_cache: - self._generator_cache[generator_id] = generator.build() - return self._generator_cache[generator_id] - - def _ensure_shared_generator(self, name: str) -> None: - """ - Validate that all generator-backed tables within the relation group share the same generator instance. 
- """ - dataset = self._datasets[name] - generator = dataset.generator - if generator is None: - return - - for related_name in self._collect_related_tables(name): - related_dataset = self._datasets[related_name] - if related_dataset.generator is None: - continue - if related_dataset.generator is not generator: - msg = ( - f"Datasets '{name}' and '{related_name}' participate in a foreign key relation " - "and must share the same DataGenerator instance." - ) - raise DataGenError(msg) - - def _collect_related_tables(self, name: str) -> set[str]: - """ - Collect all tables connected to the supplied table via foreign key relations. - """ - related: set[str] = set() - to_visit = [name] - - while to_visit: - current = to_visit.pop() - for relation in self._foreign_key_relations: - neighbor = self._neighbor_for_relation(current, relation) - if neighbor and neighbor not in related: - related.add(neighbor) - to_visit.append(neighbor) - - return related - - @staticmethod - def _neighbor_for_relation(table: str, relation: ForeignKeyRelation) -> str | None: - if relation.from_table == table: - return relation.to_table - if relation.to_table == table: - return relation.from_table - return None - - -@dataclass -class _DatasetDefinition: - """ - Internal representation of a dataset tracked by a ``MultiTableBuilder``. - """ - - name: str - generator: DataGenerator | None = None - dataframe: DataFrame | None = None - columns: tuple[ColumnLike, ...] | None = None - - def __post_init__(self) -> None: - has_generator = self.generator is not None - has_dataframe = self.dataframe is not None - - if has_generator == has_dataframe: - raise DataGenError(f"Dataset '{self.name}' must specify exactly one of DataGenerator or DataFrame.") - - def select_columns(self, df: DataFrame, overrides: Sequence[ColumnLike] | None = None) -> DataFrame: - """ - Apply column selection for the dataset using overrides when supplied. - - :param df: Source ``DataFrame`` to project - :param overrides: Optional column expressions to use instead of defaults - :return: Projected ``DataFrame`` - """ - select_exprs = overrides if overrides is not None else self.columns - if not select_exprs: - return df - - normalized_columns = [ensure_column(expr) for expr in select_exprs] - return df.select(*normalized_columns) diff --git a/dbldatagen/relation.py b/dbldatagen/relation.py deleted file mode 100644 index 5864976e..00000000 --- a/dbldatagen/relation.py +++ /dev/null @@ -1,33 +0,0 @@ -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -This module defines the ``ForeignKeyRelation`` class used for describing foreign key relations between datasets. -""" - -from dataclasses import dataclass - -from dbldatagen.datagen_types import ColumnLike -from dbldatagen.utils import ensure_column - - -@dataclass(frozen=True) -class ForeignKeyRelation: - """ - Dataclass describing a foreign key relation between two datasets managed by a ``MultiTableBuilder``. 
- - :param from_table: Name of the referencing table - :param from_column: Referencing column as a string or ``pyspark.sql.Column`` expression - :param to_table: Name of the referenced table - :param to_column: Referenced column as a string or ``pyspark.sql.Column`` expression - """ - - from_table: str - from_column: ColumnLike - to_table: str - to_column: ColumnLike - - def __post_init__(self) -> None: - object.__setattr__(self, "from_column", ensure_column(self.from_column)) - object.__setattr__(self, "to_column", ensure_column(self.to_column)) diff --git a/tests/test_multi_table.py b/tests/test_multi_table.py deleted file mode 100644 index 80d58c07..00000000 --- a/tests/test_multi_table.py +++ /dev/null @@ -1,146 +0,0 @@ -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import pytest -from pyspark.sql.functions import col - -import dbldatagen as dg -from dbldatagen.multi_table_builder import MultiTableBuilder -from dbldatagen.relation import ForeignKeyRelation -from dbldatagen.utils import DataGenError - - -spark = dg.SparkSingleton.getLocalInstance("unit tests") - - -def _build_generator(name: str, column_names: list[str], rows: int = 10) -> dg.DataGenerator: - """ - Helper to create a ``DataGenerator`` with deterministic integer columns. - """ - generator = dg.DataGenerator(sparkSession=spark, name=name, rows=rows, partitions=1) - - for index, column_name in enumerate(column_names): - min_value = index * 100 - max_value = min_value + rows - 1 - generator = generator.withColumn(column_name, "int", minValue=min_value, maxValue=max_value) - - return generator - - -class TestMultiTableBuilder: - def test_single_data_generator_builds_dataset(self) -> None: - builder = MultiTableBuilder() - generator = _build_generator("single_gen", ["order_id", "order_value"], rows=12) - - builder.add_data_generator( - name="orders", - generator=generator, - columns=[col("order_id").alias("id"), "order_value"], - ) - - results = builder.build() - - assert set(results) == {"orders"} - orders_df = results["orders"] - assert orders_df.count() == 12 - assert orders_df.columns == ["id", "order_value"] - - def test_independent_data_generators_build_individually(self) -> None: - builder = MultiTableBuilder() - - generator_a = _build_generator("generator_a", ["a_id", "a_value"], rows=5) - generator_b = _build_generator("generator_b", ["b_id", "b_value"], rows=7) - - builder.add_data_generator("table_a", generator_a, columns=["a_id", "a_value"]) - builder.add_data_generator("table_b", generator_b, columns=["b_id", "b_value"]) - - results = builder.build() - - assert set(results.keys()) == {"table_a", "table_b"} - assert results["table_a"].count() == 5 - assert results["table_b"].count() == 7 - assert len(builder.data_generators) == 2 - - def test_foreign_key_relation_requires_shared_generator(self) -> None: - builder = MultiTableBuilder() - - parent_generator = _build_generator("parent_gen", ["parent_id", "parent_value"], rows=6) - child_generator = _build_generator("child_gen", ["child_id", "child_parent_id"], rows=6) - - builder.add_data_generator("parents", parent_generator, columns=["parent_id", "parent_value"]) - builder.add_data_generator( - "children", - child_generator, - columns=["child_id", "child_parent_id"], - ) - - builder.add_foreign_key_relation( - ForeignKeyRelation( - from_table="children", - from_column="child_parent_id", - to_table="parents", - to_column="parent_id", - ) - ) - - with pytest.raises(DataGenError): - 
builder.build(["children"]) - - def test_partial_relation_with_mismatched_generators_raises(self) -> None: - builder = MultiTableBuilder() - - orders_generator = _build_generator("orders_gen", ["order_id", "order_value"], rows=6) - line_items_generator = _build_generator( - "line_items_gen", - ["line_item_id", "order_id"], - rows=8, - ) - shipments_generator = _build_generator("shipments_gen", ["shipment_id"], rows=4) - - builder.add_data_generator("orders", orders_generator, columns=["order_id", "order_value"]) - builder.add_data_generator("line_items", line_items_generator, columns=["line_item_id", "order_id"]) - builder.add_data_generator("shipments", shipments_generator, columns=["shipment_id"]) - - builder.add_foreign_key_relation( - ForeignKeyRelation( - from_table="line_items", - from_column="order_id", - to_table="orders", - to_column="order_id", - ) - ) - - with pytest.raises(DataGenError): - builder.build() - - def test_transitive_relation_requires_shared_generator(self) -> None: - builder = MultiTableBuilder() - - generator_a = _build_generator("gen_a", ["a_id", "b_id"], rows=5) - generator_b = _build_generator("gen_b", ["b_id", "c_id"], rows=5) - generator_c = _build_generator("gen_c", ["c_id"], rows=5) - - builder.add_data_generator("table_a", generator_a, columns=["a_id", "b_id"]) - builder.add_data_generator("table_b", generator_b, columns=["b_id", "c_id"]) - builder.add_data_generator("table_c", generator_c, columns=["c_id"]) - - builder.add_foreign_key_relation( - ForeignKeyRelation( - from_table="table_a", - from_column="b_id", - to_table="table_b", - to_column="b_id", - ) - ) - builder.add_foreign_key_relation( - ForeignKeyRelation( - from_table="table_b", - from_column="c_id", - to_table="table_c", - to_column="c_id", - ) - ) - - with pytest.raises(DataGenError): - builder.get_dataset("table_a") From ea8790be37475280de5be22f5511fc742bf7d455 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 15:57:43 -0500 Subject: [PATCH 5/8] Refactor --- dbldatagen/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 76aeb5c9..eaedeec0 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -74,8 +74,6 @@ from .html_utils import HtmlUtils from .datasets_object import Datasets from .config import OutputDataset -from .multi_table_builder import MultiTableBuilder -from .relation import ForeignKeyRelation from .datagen_types import ColumnLike __all__ = [ @@ -96,8 +94,6 @@ "datasets_object", "constraints", "config", - "multi_table_builder", - "relation", "datagen_types", ] From 5640a04862a514cf46ed83431c1427b7a5265af6 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Mon, 8 Dec 2025 16:31:46 -0500 Subject: [PATCH 6/8] Improve test coverage --- tests/test_constraints.py | 4 ++++ tests/test_quick_tests.py | 32 ++++++++++++++++++++++++-------- tests/test_utils.py | 9 +++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 0f88b00e..c993145b 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -60,6 +60,10 @@ def test_simple_constraints2(self, generationSpec1): rowCount = testDataDF.count() assert rowCount == 100 + def test_sql_expr_requires_non_empty_expression(self): + with pytest.raises(ValueError, match="Expression must be a valid non-empty SQL string"): + SqlExpr("") + def test_multiple_constraints(self, generationSpec1): testDataSpec = generationSpec1.withConstraints([SqlExpr("id < 100"), SqlExpr("id > 
0")]) diff --git a/tests/test_quick_tests.py b/tests/test_quick_tests.py index 9ba0c0fb..bde2962c 100644 --- a/tests/test_quick_tests.py +++ b/tests/test_quick_tests.py @@ -465,37 +465,55 @@ def test_empty_range(self): assert empty_range.isEmpty() def test_nrange_legacy_min_and_minvalue_conflict(self): - """Ensure conflicting legacy 'min' and 'minValue' arguments raise a clear error.""" with pytest.raises(ValueError, match="Only one of 'minValue' and legacy 'min' may be specified"): NRange(minValue=0.0, min=1.0) def test_nrange_legacy_min_must_be_numeric(self): - """Ensure legacy 'min' argument must be numeric.""" with pytest.raises(ValueError, match=r"Legacy 'min' argument must be an integer or float\."): NRange(min="not-a-number") + def test_nrange_legacy_max_and_maxvalue_conflict(self): + with pytest.raises(ValueError, match="Only one of 'maxValue' and legacy 'max' may be specified"): + NRange(maxValue=10.0, max=11.0) + + def test_nrange_legacy_max_must_be_numeric(self): + with pytest.raises(ValueError, match=r"Legacy 'max' argument must be an integer or float\."): + NRange(max="not-a-number") + def test_nrange_unexpected_kwargs_error_message(self): - """Ensure unexpected keyword arguments produce a helpful error.""" with pytest.raises(ValueError, match=r"Unexpected keyword arguments for NRange: .*"): NRange(foo=1) def test_nrange_maxvalue_and_until_conflict(self): - """Ensure conflicting 'maxValue' and 'until' arguments raise a clear error.""" with pytest.raises(ValueError, match="Only one of 'maxValue' or 'until' may be specified."): NRange(maxValue=10, until=20) def test_nrange_discrete_range_requires_min_max_step(self): - """Ensure getDiscreteRange validates required attributes.""" rng = NRange(minValue=0.0, maxValue=10.0) with pytest.raises(ValueError, match="Range must have 'minValue', 'maxValue', and 'step' defined\\."): _ = rng.getDiscreteRange() def test_nrange_discrete_range_step_must_be_non_zero(self): - """Ensure getDiscreteRange validates non-zero step.""" rng = NRange(minValue=0.0, maxValue=10.0, step=0) with pytest.raises(ValueError, match="Parameter 'step' must be non-zero when computing discrete range\\."): _ = rng.getDiscreteRange() + def test_nrange_adjust_for_byte_type_maxvalue_out_of_range(self): + rng = NRange(maxValue=300) # above allowed ByteType max of 256 + with pytest.raises( + ValueError, + match=r"`maxValue` must be within the valid range \(0 - 256\) for ByteType\.", + ): + rng.adjustForColumnDatatype(ByteType()) + + def test_nrange_get_continuous_range_requires_min_and_max(self): + rng = NRange(minValue=None, maxValue=10.0) + with pytest.raises( + ValueError, + match=r"Range must have 'minValue' and 'maxValue' defined\.", + ): + _ = rng.getContinuousRange() + def test_reversed_ranges(self): testDataSpec = ( dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000, partitions=4) @@ -561,8 +579,6 @@ def test_date_time_ranges(self): rowCount = rangedDF.count() assert rowCount == 100000 - # TODO: add additional validation statement - @pytest.mark.parametrize("asHtml", [True, False]) def test_script_table(self, asHtml): testDataSpec = ( diff --git a/tests/test_utils.py b/tests/test_utils.py index d28b606e..352a93eb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,8 @@ from datetime import timedelta import pytest +from pyspark.sql import Column +from pyspark.sql.functions import current_date from dbldatagen import ( ensure, @@ -18,6 +20,8 @@ system_time_millis, ) +from dbldatagen.utils import ensure_column + spark = 
SparkSingleton.getLocalInstance("unit tests") @@ -225,3 +229,8 @@ def test_topological_sort_cycle_error_message(self): deps = [("a", {"b"}), ("b", {"a"})] with pytest.raises(ValueError, match="cyclic or missing dependency detected"): topologicalSort(deps) + + @pytest.mark.parametrize("col", ["col1", current_date()]) + def test_ensure_column(self, col): + ensured = ensure_column(col) + assert isinstance(ensured, Column) From 4baea987cc14c156c78db6beef1ab6a369a272b3 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Wed, 10 Dec 2025 09:12:10 -0500 Subject: [PATCH 7/8] Fix attribute handling for Constraint --- CHANGELOG.md | 4 +++- dbldatagen/constraints/constraint.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc8868b7..650212d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,11 +6,12 @@ All notable changes to the Databricks Labs Data Generator will be documented in ### unreleased #### Fixed -* Updated build scripts to use Ubuntu 22.04 to correspond to environment in Databricks runtime * Refactored `DataAnalyzer` and `BasicStockTickerProvider` to comply with ANSI SQL standards * Removed internal modification of `SparkSession` +* Updated `Constraint` to treat `_filterExpression` and `_calculatedFilterExpression` as instance variables #### Changed +* Added type hints for modules and classes * Changed base Databricks runtime version to DBR 13.3 LTS (based on Apache Spark 3.4.1) - minimum supported version of Python is now 3.10.12 * Updated build tooling to use [hatch](https://hatch.pypa.io/latest/) @@ -23,6 +24,7 @@ All notable changes to the Databricks Labs Data Generator will be documented in #### Added * Added support for serialization to/from JSON format * Added Ruff and mypy tooling +* Added `OutputDataset` class and the ability to save a `DataGenerator` to an output table or files ### Version 0.4.0 Hotfix 2 diff --git a/dbldatagen/constraints/constraint.py b/dbldatagen/constraints/constraint.py index e3cb8781..59d130b3 100644 --- a/dbldatagen/constraints/constraint.py +++ b/dbldatagen/constraints/constraint.py @@ -28,12 +28,11 @@ class Constraint(SerializableToDict, ABC): """ SUPPORTED_OPERATORS: ClassVar[list[str]] = ["<", ">", ">=", "!=", "==", "=", "<=", "<>"] - _filterExpression: Column | None = None - _calculatedFilterExpression: bool = False - _supportsStreaming: bool = False def __init__(self, supportsStreaming: bool = False) -> None: self._supportsStreaming = supportsStreaming + self._filterExpression: Column | None = None + self._calculatedFilterExpression: bool = False @staticmethod def _columnsFromListOrString( From 3e1f5a2a07718a865a9d019a64d1bcee734c5cf8 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Wed, 10 Dec 2025 09:17:37 -0500 Subject: [PATCH 8/8] Reformat changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 650212d3..6014ba19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,8 @@ All notable changes to the Databricks Labs Data Generator will be documented in #### Fixed * Refactored `DataAnalyzer` and `BasicStockTickerProvider` to comply with ANSI SQL standards +* Refactored `Constraint` to treat `_filterExpression` and `_calculatedFilterExpression` as instance variables * Removed internal modification of `SparkSession` -* Updated `Constraint` to treat `_filterExpression` and `_calculatedFilterExpression` as instance variables #### Changed * Added type hints for modules and classes
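For context on the `Constraint` change in PATCH 7/8: the usual motivation for moving attributes such as `_filterExpression` and `_calculatedFilterExpression` out of the class body and into `__init__` is that class-level defaults live on the class object itself, so a value written through the class (or shared via subclasses) is visible to every instance that has not shadowed it, whereas attributes assigned in `__init__` are unambiguously per-instance. The sketch below illustrates that difference in isolation; `ClassLevelDefaults` and `InstanceLevelState` are hypothetical stand-ins, not the library's actual `Constraint` implementation.

# Minimal sketch of class-level defaults vs. per-instance state (names are illustrative only).

class ClassLevelDefaults:
    _calculated = False  # one value, stored on the class object


class InstanceLevelState:
    def __init__(self) -> None:
        self._calculated = False  # one value per instance, set at construction time


a, b = ClassLevelDefaults(), ClassLevelDefaults()
ClassLevelDefaults._calculated = True   # e.g. a helper that (incorrectly) writes via the class
print(a._calculated, b._calculated)     # True True - the change is visible to both instances

c, d = InstanceLevelState(), InstanceLevelState()
c._calculated = True                    # writes an attribute owned by this instance only
print(c._calculated, d._calculated)     # True False - no cross-instance leakage

Initializing the flags per instance means a constraint that lazily computes and caches its filter expression can never observe a cached value belonging to a different constraint object, which is consistent with the CHANGELOG wording "Refactored `Constraint` to treat `_filterExpression` and `_calculatedFilterExpression` as instance variables".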