Skip to content

Commit 0d0f4c2

Browse files
committed
Add serialization for distributions
1 parent 612618d commit 0d0f4c2

File tree

5 files changed

+23
-1
lines changed

5 files changed

+23
-1
lines changed

dbldatagen/distributions/beta.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def __init__(self, alpha=None, beta=None):
3535
self._alpha = alpha
3636
self._beta = beta
3737

38+
@classmethod
39+
def getMapping(cls):
40+
return {"alpha": "_alpha", "beta": "_beta"}
41+
3842
@property
3943
def alpha(self):
4044
""" Return alpha parameter."""

dbldatagen/distributions/data_distribution.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,20 @@
2525
import numpy as np
2626
import pyspark.sql.functions as F
2727

28+
from ..serialization import Serializable
2829

29-
class DataDistribution(ABC):
30+
31+
class DataDistribution(Serializable, ABC):
3032
""" Base class for all distributions"""
3133

3234
def __init__(self):
3335
self._rounding = False
3436
self._randomSeed = None
3537

38+
@classmethod
39+
def getMapping(cls):
40+
raise NotImplementedError("method not implemented")
41+
3642
@staticmethod
3743
def get_np_random_generator(random_seed):
3844
""" Get numpy random number generator

dbldatagen/distributions/exponential_distribution.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def __init__(self, rate=None):
3030
DataDistribution.__init__(self)
3131
self._rate = rate
3232

33+
@classmethod
34+
def getMapping(cls):
35+
return {"rate": "_rate"}
36+
3337
def __str__(self):
3438
""" Return string representation"""
3539
return f"ExponentialDistribution(rate={self.rate}, randomSeed={self.randomSeed})"

dbldatagen/distributions/gamma.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def __init__(self, shape, scale):
3434
self._shape = shape
3535
self._scale = scale
3636

37+
@classmethod
38+
def getMapping(cls):
39+
return {"shape": "_shape", "scale": "_scale"}
40+
3741
@property
3842
def shape(self):
3943
""" Return shape parameter."""

dbldatagen/distributions/normal_distribution.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def __init__(self, mean, stddev):
2727
self.mean = mean if mean is not None else 0.0
2828
self.stddev = stddev if stddev is not None else 1.0
2929

30+
@classmethod
31+
def getMapping(cls):
32+
return {"mean": "mean", "stddev": "stddev"}
33+
3034
@staticmethod
3135
def normal_func(mean_series: pd.Series, std_dev_series: pd.Series, random_seed: pd.Series) -> pd.Series:
3236
""" Pandas / Numpy based function to generate normal / gaussian samples

0 commit comments

Comments
 (0)