Skip to content

Commit 3eaf7f1

Browse files
committed
Add serialization for DataGenerators and ColumnGenerationSpecs
1 parent 0d0f4c2 commit 3eaf7f1

File tree

9 files changed

+208
-7
lines changed

9 files changed

+208
-7
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pandas = "==1.2.4"
2626
setuptools = "==65.6.3"
2727
pyparsing = "==2.4.7"
2828
jmespath = "==0.10.0"
29+
pyyaml = ">=6.0.2"
2930

3031
[requires]
3132
python_version = "3.8.12"

dbldatagen/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from ._version import __version__
3535
from .column_generation_spec import ColumnGenerationSpec
3636
from .column_spec_options import ColumnSpecOptions
37+
from .constraints import Constraint, ChainedRelation, LiteralRange, LiteralRelation, NegativeValues, PositiveValues, \
38+
RangedValues, SqlExpr, UniqueCombinations
3739
from .data_analyzer import DataAnalyzer
3840
from .schema_parser import SchemaParser
3941
from .daterange import DateRange
@@ -49,7 +51,7 @@
4951
__all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange",
5052
"column_generation_spec", "utils", "function_builder",
5153
"spark_singleton", "text_generators", "datarange", "datagen_constants",
52-
"text_generator_plugins", "html_utils", "datasets_object"
54+
"text_generator_plugins", "html_utils", "datasets_object", "constraints"
5355
]
5456

5557

dbldatagen/column_generation_spec.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .daterange import DateRange
2626
from .distributions import Normal, DataDistribution
2727
from .nrange import NRange
28+
from .serialization import Serializable
2829
from .text_generators import TemplateGenerator
2930
from .utils import ensure, coalesce_values
3031
from .schema_parser import SchemaParser
@@ -40,7 +41,7 @@
4041
RAW_VALUES_COMPUTE_METHOD]
4142

4243

43-
class ColumnGenerationSpec(object):
44+
class ColumnGenerationSpec(Serializable):
4445
""" Column generation spec object - specifies how column is to be generated
4546
4647
Each column to be output will have a corresponding ColumnGenerationSpec object.
@@ -119,7 +120,7 @@ def __init__(self, name, colType=None, minValue=0, maxValue=None, step=1, prefix
119120
if EXPR_OPTION not in kwargs:
120121
raise ValueError("Column generation spec must have `expr` attribute specified if datatype is inferred")
121122

122-
elif type(colType) == str:
123+
elif isinstance(colType, str):
123124
colType = SchemaParser.columnTypeFromString(colType)
124125

125126
assert isinstance(colType, DataType), f"colType `{colType}` is not instance of DataType"
@@ -299,6 +300,29 @@ def __init__(self, name, colType=None, minValue=0, maxValue=None, step=1, prefix
299300
# set up the temporary columns needed for data generation
300301
self._setupTemporaryColumns()
301302

303+
@classmethod
def getMapping(cls):
    """ Get the mapping used when serializing a column generation spec.

    :return: Dictionary whose keys are option names and whose values are the names of the
             object attributes holding the corresponding values
    """
    # TODO: ADD ALL COLUMN SPEC OPTIONS?
    option_to_attribute = [
        ("colName", "name"),
        ("colType", "typeString"),
        ("minValue", "min"),
        ("maxValue", "max"),
        ("step", "step"),
        ("prefix", "prefix"),
        ("random", "random"),
        ("randomSeed", "_randomSeed"),
        ("randomSeedMethod", "_randomSeedMethod"),
        ("implicit", "implicit"),
        ("omit", "omit"),
        ("nullable", "nullable"),
        ("values", "values"),
        ("weights", "weights"),
        ("distribution", "distribution"),
        ("baseColumn", "baseColumn"),
        ("dataRange", "dataRange"),
    ]
    return dict(option_to_attribute)
325+
302326
def _temporaryRename(self, tmpName):
303327
""" Create enter / exit object to support temporary renaming of column spec
304328
@@ -417,6 +441,11 @@ def inferDatatype(self):
417441
"""
418442
return self._inferDataType
419443

444+
@property
def typeString(self):
    """ Get the simple string representing the column type.

    :return: result of calling `simpleString()` on this column's `datatype`
    """
    column_type = self.datatype
    return column_type.simpleString()
448+
420449
@property
421450
def baseColumns(self):
422451
""" Return base columns as list of strings"""
@@ -836,6 +865,10 @@ def numFeatures(self):
836865
"""
837866
return self['numFeatures']
838867

868+
@property
def dataRange(self):
    """ Get the value held in the `_dataRange` attribute of this column spec."""
    data_range = self._dataRange
    return data_range
871+
839872
def structType(self):
840873
"""get the `structType` attribute used to generate values for this column
841874

dbldatagen/data_generator.py

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,25 @@
66
This file defines the `DataGenError` and `DataGenerator` classes
77
"""
88
import copy
9+
import json
910
import logging
1011
import re
1112

13+
import yaml
1214
from pyspark.sql.types import LongType, IntegerType, StringType, StructType, StructField, DataType
1315

1416
from ._version import _get_spark_version
1517
from .column_generation_spec import ColumnGenerationSpec
16-
from .constraints.constraint import Constraint
17-
from .constraints.sql_expr import SqlExpr
18+
from .constraints import Constraint, SqlExpr
19+
from .datarange import DataRange
20+
from .distributions import DataDistribution
21+
1822
from .datagen_constants import DEFAULT_RANDOM_SEED, RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME, \
1923
DEFAULT_SEED_COLUMN, SPARK_RANGE_COLUMN, MIN_SPARK_VERSION, \
2024
OPTION_RANDOM, OPTION_RANDOM_SEED, OPTION_RANDOM_SEED_METHOD, \
2125
INFER_DATATYPE, SPARK_DEFAULT_PARALLELISM
2226
from .html_utils import HtmlUtils
27+
from .serialization import Serializable
2328
from .schema_parser import SchemaParser
2429
from .spark_singleton import SparkSingleton
2530
from .utils import ensure, topologicalSort, DataGenError, deprecated, split_list_matching_condition
@@ -30,7 +35,7 @@
3035
_STREAMING_TIMESTAMP_COLUMN = "_source_timestamp"
3136

3237

33-
class DataGenerator:
38+
class DataGenerator(Serializable):
3439
""" Main Class for test data set generation
3540
3641
This class acts as the entry point to all test data generation activities.
@@ -173,6 +178,50 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
173178
# set up use of pandas udfs
174179
self._setupPandas(batchSize)
175180

181+
@classmethod
def getMapping(cls):
    """ Get the mapping used when serializing a data generator.

    :return: Dictionary whose keys are constructor option names and whose values are the names of
             the object attributes read by `toDict` to obtain the corresponding values
    """
    option_to_attribute = [
        ("name", "name"),
        ("randomSeedMethod", "_seedMethod"),
        ("rows", "_rowCount"),
        ("startingId", "starting_id"),
        ("randomSeed", "_randomSeed"),
        ("partitions", "partitions"),
        ("verbose", "verbose"),
        ("batchSize", "_batchSize"),
        ("debug", "debug"),
        ("seedColumnName", "_seedColumnName"),
        ("random", "_defaultRandom"),
    ]
    return dict(option_to_attribute)
196+
197+
@classmethod
def fromDict(cls, options):
    """ Creates a DataGenerator instance from a Python dictionary.

    :param options: Python dictionary of options for the DataGenerator, ColumnGenerationSpecs, and Constraints.
                    The optional `columns` and `constraints` entries hold lists of dictionaries; a missing
                    or empty (`None`) entry is treated as an empty list.
    :return: Instance of the class the method was called on, with columns and constraints applied
    """
    # work on a shallow copy so the caller's dictionary is not modified by the pops below
    ir = options.copy()
    columns = ir.pop("columns", None) or []
    constraints = ir.pop("constraints", None) or []

    # any remaining list-valued entries are not forwarded to the constructor
    # NOTE(review): a `kind` entry (as written by `toDict`) is forwarded as a keyword argument -
    # confirm the constructor tolerates it
    init_options = {k: v for k, v in ir.items() if not isinstance(v, list)}

    # build via `cls` rather than a hard-coded class so subclasses get instances of themselves
    return (
        cls(**init_options)
        .withColumnDefinitions(columns)
        .withConstraintDefinitions(constraints)
    )
211+
212+
def toDict(self):
    """ Creates a Python dictionary from a DataGenerator instance.

    :return: Python dictionary of options for the DataGenerator, ColumnGenerationSpecs, and Constraints
    """
    result = {}
    # generator level options: read each attribute named by the mapping
    for constructor_key, object_key in self.getMapping().items():
        result[constructor_key] = getattr(self, object_key)

    # column definitions are written without their `kind` entry
    column_dicts = []
    for spec in self.getColumnGenerationSpecs():
        spec_dict = spec.toDict()
        column_dicts.append({k: v for k, v in spec_dict.items() if k != "kind"})
    result["columns"] = column_dicts

    result["constraints"] = [item.toDict() for item in self.getConstraints()]
    result["kind"] = self.__class__.__name__
    return result
224+
176225
@property
177226
def seedColumnName(self):
178227
""" return the name of data generation seed column"""
@@ -869,6 +918,26 @@ def withColumn(self, colName, colType=StringType(), minValue=None, maxValue=None
869918
self._inferredSchemaFields.append(StructField(colName, newColumn.datatype, nullable))
870919
return self
871920

921+
def withColumnDefinitions(self, columns):
    """ Adds a set of columns to the synthetic generation specification.

    :param columns: A list of column generation specifications as dictionaries. Each dictionary is passed
                    to `withColumn` as keyword arguments; a `name` entry is accepted in place of `colName`.
                    `dataRange` and `distribution` entries given as dictionaries are rebuilt into objects
                    using their `kind` entry; any other value is passed through unchanged.
    :returns: A modified in-place instance of a data generator allowing for chaining of calls
              following a builder pattern
    :raises ValueError: if the `kind` of a `dataRange` or `distribution` dictionary does not name a
                        known subclass
    """

    def _from_kind(base, definition, option):
        """ Build the direct subclass of `base` named by `definition["kind"]` from its dictionary form."""
        kind = definition.get("kind")
        matches = [s for s in base.__subclasses__() if s.__name__ == kind]
        if not matches:
            raise ValueError(f"Unknown kind `{kind}` for column option `{option}`")
        return matches[0].fromDict(definition)

    for column in columns:
        # copy so that the caller's dictionaries are left untouched
        internal_column = column.copy()
        if "colName" not in internal_column:
            internal_column["colName"] = internal_column.pop("name")

        data_range = internal_column.get("dataRange")
        if isinstance(data_range, dict):
            internal_column["dataRange"] = _from_kind(DataRange, data_range, "dataRange")

        distribution = internal_column.get("distribution")
        if isinstance(distribution, dict):
            internal_column["distribution"] = _from_kind(DataDistribution, distribution, "distribution")

        self.withColumn(**internal_column)
    return self
940+
872941
def _mkSqlStructFromList(self, fields):
873942
"""
874943
Create a SQL struct expression from a list of fields
@@ -1206,6 +1275,12 @@ def _getColumnDataTypes(self, columns):
12061275
"""
12071276
return [self._columnSpecsByName[colspec].datatype for colspec in columns]
12081277

1278+
def getColumnGenerationSpecs(self):
    """ Get the column generation specs held by this data generator.

    :return: the `_allColumnSpecs` attribute itself (not a copy)
    """
    specs = self._allColumnSpecs
    return specs
1280+
1281+
def getConstraints(self):
    """ Get the constraints held by this data generator.

    :return: the `_constraints` attribute itself (not a copy)
    """
    constraints = self._constraints
    return constraints
1283+
12091284
def withConstraint(self, constraint):
12101285
"""Add a constraint to control the data generation
12111286
@@ -1255,6 +1330,18 @@ def withSqlConstraint(self, sqlExpression: str):
12551330
self.withConstraint(SqlExpr(sqlExpression))
12561331
return self
12571332

1333+
def withConstraintDefinitions(self, constraints):
    """ Adds a set of constraints to the synthetic generation specification.

    :param constraints: A list of constraints as dictionaries. The `kind` entry of each dictionary
                        names the direct `Constraint` subclass used to rebuild it via `fromDict`.
    :returns: A modified in-place instance of a data generator allowing for chaining of calls
              following a builder pattern
    :raises ValueError: if a constraint's `kind` is missing or does not name a known `Constraint` subclass
    """
    for definition in constraints:
        kind = definition.get("kind")
        matches = [s for s in Constraint.__subclasses__() if s.__name__ == kind]
        if not matches:
            # fail with the offending kind instead of an IndexError from an empty list
            raise ValueError(f"Unknown constraint kind `{kind}`")
        self.withConstraint(matches[0].fromDict(definition))
    return self
1344+
12581345
def computeBuildPlan(self):
12591346
""" prepare for building by computing a pseudo build plan
12601347
@@ -1604,3 +1691,33 @@ def scriptMerge(self, tgtName=None, srcName=None, updateExpr=None, delExpr=None,
16041691
result = HtmlUtils.formatCodeAsHtml(results)
16051692

16061693
return result
1694+
1695+
@classmethod
def fromJson(cls, options):
    """ Creates a data generator from a JSON string.

    :param options: A JSON string containing data generation options; must describe a JSON object
    :return: A data generator with the specified options, built through `fromDict` of the class
             the method was called on
    :raises ValueError: if the JSON document is not an object
    """
    parsed = json.loads(options)
    if not isinstance(parsed, dict):
        raise ValueError(f"JSON options must be an object, found `{type(parsed).__name__}`")
    return cls.fromDict(parsed)
1703+
1704+
def toJson(self):
    """ Returns the JSON string representation of a data generator.

    :return: A JSON string produced from the dictionary returned by `toDict`
    """
    generator_options = self.toDict()
    return json.dumps(generator_options)
1709+
1710+
@classmethod
def fromYaml(cls, options):
    """ Creates a data generator from a YAML string.

    :param options: A YAML string containing data generation options; must describe a YAML mapping
    :return: A data generator with the specified options, built through `fromDict` of the class
             the method was called on
    :raises ValueError: if the YAML document is empty or is not a mapping
    """
    parsed = yaml.safe_load(options)
    if not isinstance(parsed, dict):
        raise ValueError(f"YAML options must be a mapping, found `{type(parsed).__name__}`")
    return cls.fromDict(parsed)
1718+
1719+
def toYaml(self):
    """ Returns the YAML string representation of a data generator.

    :return: A YAML string produced from the dictionary returned by `toDict`
    """
    generator_options = self.toDict()
    return yaml.dump(generator_options)

dbldatagen/datarange.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,16 @@
1010
1111
"""
1212

13+
from .serialization import Serializable
1314

14-
class DataRange(object):
15+
16+
class DataRange(Serializable):
1517
""" Abstract class used as base class for NRange and DateRange """
1618

19+
@classmethod
def getMapping(cls):
    """ Get the serialization mapping for the data range.

    Not implemented on this base class.

    :raises NotImplementedError: always
    """
    message = "method not implemented"
    raise NotImplementedError(message)
22+
1723
def isEmpty(self):
1824
"""Check if object is empty (i.e all instance vars of note are `None`)"""
1925
raise NotImplementedError("method not implemented")

dbldatagen/daterange.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, begin, end, interval=None, datetime_format=DEFAULT_UTC_TS_FOR
4444
assert begin is not None, "`begin` must be specified"
4545
assert end is not None, "`end` must be specified"
4646

47+
self.datetime_format = datetime_format
4748
self.begin = begin if not isinstance(begin, str) else self._datetime_from_string(begin, datetime_format)
4849
self.end = end if not isinstance(end, str) else self._datetime_from_string(end, datetime_format)
4950
self.interval = interval if not isinstance(interval, str) else self._timedelta_from_string(interval)
@@ -54,12 +55,37 @@ def __init__(self, begin, end, interval=None, datetime_format=DEFAULT_UTC_TS_FOR
5455
* self.computeTimestampIntervals(self.begin, self.end, self.interval))
5556
self.step = self.interval.total_seconds()
5657

58+
@classmethod
def getMapping(cls):
    """ Get the mapping used when serializing a date range.

    :return: Dictionary whose keys are option names and whose values are the names of the
             attributes or properties holding the corresponding values
    """
    option_to_attribute = [
        ("begin", "begin_string"),
        ("end", "end_string"),
        ("interval", "interval_string"),
        ("datetime_format", "datetime_format"),
    ]
    return dict(option_to_attribute)
66+
67+
@property
def begin_string(self):
    """ Get `begin` formatted as a string using this range's `datetime_format`."""
    output_format = self.datetime_format
    return self._string_from_datetime(self.begin, output_format)
70+
71+
@property
def end_string(self):
    """ Get `end` formatted as a string using this range's `datetime_format`."""
    output_format = self.datetime_format
    return self._string_from_datetime(self.end, output_format)
74+
75+
@property
def interval_string(self):
    """ Get `interval` as a string built by `formatInterval` from its length in whole seconds."""
    # NOTE(review): int() drops any fractional seconds of the interval - confirm that
    # sub-second intervals are not expected here
    whole_seconds = int(self.interval.total_seconds())
    return self.formatInterval(whole_seconds)
78+
5779
@classmethod
5880
def _datetime_from_string(cls, date_str, date_format):
5981
"""convert string to Python DateTime object using format"""
6082
result = datetime.strptime(date_str, date_format)
6183
return result
6284

85+
@classmethod
86+
def _string_from_datetime(cls, date_str, date_format):
87+
return datetime.strftime(date_str, date_format)
88+
6389
@classmethod
6490
def _timedelta_from_string(cls, interval):
6591
return cls.parseInterval(interval)
@@ -70,6 +96,11 @@ def parseInterval(cls, interval_str):
7096
assert interval_str is not None, "`interval_str` must be specified"
7197
return parse_time_interval(interval_str)
7298

99+
@classmethod
def formatInterval(cls, interval_time_seconds):
    """ Format a number of seconds as an interval string.

    :param interval_time_seconds: number of seconds in the interval; must not be `None`
    :return: string of the form `INTERVAL <interval_time_seconds> SECONDS`
    """
    # the message names the actual parameter so a failure points at the right argument
    assert interval_time_seconds is not None, "`interval_time_seconds` must be specified"
    return f"INTERVAL {interval_time_seconds} SECONDS"
103+
73104
@classmethod
74105
def _getDateTime(cls, dt, datetime_format, default_value):
75106
if isinstance(dt, str):

dbldatagen/nrange.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,18 @@ def __init__(self, minValue=None, maxValue=None, step=None, until=None, **kwArgs
5757
assert self.maxValue is None if until is not None else True, "Only one of maxValue or until can be specified"
5858

5959
if until is not None:
60+
self.until = until
6061
self.maxValue = until + 1
6162
self.step = step
6263

64+
@classmethod
def getMapping(cls):
    """ Get the mapping used when serializing a numeric range.

    :return: Dictionary whose keys are option names and whose values are the names of the
             attributes holding the corresponding values
    """
    option_to_attribute = [
        ("minValue", "minValue"),
        ("maxValue", "maxValue"),
        ("step", "step"),
    ]
    return dict(option_to_attribute)
71+
6372
def __str__(self):
6473
return f"NRange({self.minValue}, {self.maxValue}, {self.step})"
6574

python/dev_require.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python-dateutil==2.8.1
1010
six==1.15.0
1111
pyparsing==2.4.7
1212
jmespath==0.10.0
13+
pyyaml>=6.0.2
1314

1415
# The following packages are required for development only
1516
wheel==0.36.2

python/require.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python-dateutil==2.8.1
1010
six==1.15.0
1111
pyparsing==2.4.7
1212
jmespath==0.10.0
13+
pyyaml>=6.0.2
1314

1415
# The following packages are required for development only
1516
wheel==0.36.2

0 commit comments

Comments
 (0)