Skip to content

Commit 612618d

Browse files
committed
Add serialization for constraints
1 parent e0526d8 commit 612618d

File tree

9 files changed

+34
-1
lines changed

9 files changed

+34
-1
lines changed

dbldatagen/constraints/chained_relation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def __init__(self, columns, relation):
3838
if not isinstance(self._columns, list) or len(self._columns) <= 1:
3939
raise ValueError("ChainedRelation constraints must be defined across more than one column")
4040

41+
@classmethod
42+
def getMapping(cls):
43+
return {"relation": "_relation", "columns": "_columns"}
44+
4145
def _generateFilterExpression(self):
4246
""" Generated composite filter expression for chained set of filter expressions
4347

dbldatagen/constraints/constraint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import types
99
from abc import ABC, abstractmethod
1010
from pyspark.sql import Column
11+
from ..serialization import Serializable
1112

1213

13-
class Constraint(ABC):
14+
class Constraint(Serializable, ABC):
1415
""" Constraint object - base class for predefined and custom constraints
1516
1617
This class is meant for internal use only.

dbldatagen/constraints/literal_range_constraint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def __init__(self, columns, lowValue, highValue, strict=False):
2929
self._highValue = highValue
3030
self._strict = strict
3131

32+
@classmethod
33+
def getMapping(cls):
34+
return {"columns": "_columns", "lowValue": "_lowValue", "highValue": "_highValue", "strict": "_strict"}
35+
3236
def _generateFilterExpression(self):
3337
""" Generate a SQL filter expression that may be used for filtering"""
3438
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/literal_relation_constraint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def __init__(self, columns, relation, value):
2929
if relation not in self.SUPPORTED_OPERATORS:
3030
raise ValueError(f"Parameter `relation` should be one of the operators :{self.SUPPORTED_OPERATORS}")
3131

32+
@classmethod
33+
def getMapping(cls):
34+
return {"columns": "_columns", "relation": "_relation", "value": "_value"}
35+
3236
def _generateFilterExpression(self):
3337
expressions = [F.col(colname) for colname in self._columns]
3438
literalValue = F.lit(self._value)

dbldatagen/constraints/negative_values.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def __init__(self, columns, strict=False):
2727
self._columns = self._columnsFromListOrString(columns)
2828
self._strict = strict
2929

30+
@classmethod
31+
def getMapping(cls):
32+
return {"columns": "_columns", "strict": "_strict"}
33+
3034
def _generateFilterExpression(self):
3135
expressions = [F.col(colname) for colname in self._columns]
3236
if self._strict:

dbldatagen/constraints/positive_values.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def __init__(self, columns, strict=False):
2727
self._columns = self._columnsFromListOrString(columns)
2828
self._strict = strict
2929

30+
@classmethod
31+
def getMapping(cls):
32+
return {"columns": "_columns", "strict": "_strict"}
33+
3034
def _generateFilterExpression(self):
3135
""" Generate a filter expression that may be used for filtering"""
3236
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/ranged_values_constraint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def __init__(self, columns, lowValue, highValue, strict=False):
2828
self._highValue = highValue
2929
self._strict = strict
3030

31+
@classmethod
32+
def getMapping(cls):
33+
return {"columns": "_columns", "lowValue": "_lowValue", "highValue": "_highValue", "strict": "_strict"}
34+
3135
def _generateFilterExpression(self):
3236
""" Generate a SQL filter expression that may be used for filtering"""
3337
expressions = [F.col(colname) for colname in self._columns]

dbldatagen/constraints/sql_expr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def __init__(self, expr: str):
2525
assert isinstance(expr, str) and len(expr.strip()) > 0, "Expression must be a valid SQL string"
2626
self._expr = expr
2727

28+
@classmethod
29+
def getMapping(cls):
30+
return {"expr": "_expr"}
31+
2832
def _generateFilterExpression(self):
2933
""" Generate a SQL filter expression that may be used for filtering"""
3034
return F.expr(self._expr)

dbldatagen/constraints/unique_combinations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def __init__(self, columns=None):
4545
else:
4646
self._columns = None
4747

48+
@classmethod
49+
def getMapping(cls):
50+
return {"columns": "_columns"}
51+
4852
def prepareDataGenerator(self, dataGenerator):
4953
""" Prepare the data generator to generate data that matches the constraint
5054

0 commit comments

Comments
 (0)