 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+"""
+This file defines the `ColumnGeneratorBuilder` class and utility functions
+"""
+
 import itertools
+from typing import Any
 
-from pyspark.sql.types import StringType, DateType, TimestampType
+from pyspark.sql.types import DataType, DateType, StringType, TimestampType
 
 
 class ColumnGeneratorBuilder:
-    """ Helper class to build functional column generators of specific forms"""
+    """
+    Helper class to build functional column generators of specific forms
+    """
 
     @classmethod
-    def _mkList(cls, x):
+    def _mkList(cls, x: object) -> list:
+        """
+        Makes a list of the supplied object instance if it is not already a list.
+
+        :param x: Input object to process
+        :returns: List containing the supplied object if it is not already a list; otherwise returns the object
         """
-        Makes a list of the supplied object instance if it is not already a list
-        :param x: object to process
-        :returns: Returns list of supplied object if it is not already a list, otherwise simply returns the object"""
         return [x] if type(x) is not list else x
 
     @classmethod
-    def _lastElement(cls, x):
-        """ Gets the last element, if the object is a list otherwise returns the object itself"""
-        return x[-1] if type(x) is list else x
+    def _lastElement(cls, x: object) -> object:
+        """
+        Gets the last element from the supplied object if it is a list.
+
+        :param x: Input object
+        :returns: Last element of the input object if it is a list; otherwise returns the object
+        """
+        return x[-1] if isinstance(x, list) else x
 
     @classmethod
-    def _mkCdfProbabilities(cls, weights):
-        """ make cumulative distribution function probabilities for each value in values list
+    def _mkCdfProbabilities(cls, weights: list[float]) -> list[float]:
+        """
+        Makes cumulative distribution function probabilities for each value in values list.
 
         a cumulative distribution function for discrete values can use
         a table of cumulative probabilities to evaluate different expressions
@@ -46,6 +62,9 @@ def _mkCdfProbabilities(cls, weights):
         while datasets of size 10,000 x `number of values` give a repeated
         distribution within 5% of expected distribution.
 
+        :param weights: List of weights to compute CDF probabilities for
+        :returns: List of CDF probabilities
+
         Example code to be generated (pseudo code)::
 
            # given values value1 .. valueN, prob1 to probN
@@ -61,13 +80,12 @@ def _mkCdfProbabilities(cls, weights):
 
         """
         total_weights = sum(weights)
-        return list(map(lambda x: x / total_weights, itertools.accumulate(weights)))
+        return [x / total_weights for x in itertools.accumulate(weights)]
 
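    # Worked example (editor's illustration, not part of the module): for weights [1, 2, 1],
    # itertools.accumulate produces the running totals [1, 3, 4]; dividing each by the total
    # weight (4) yields the CDF probabilities returned above, i.e.
    # _mkCdfProbabilities([1, 2, 1]) -> [0.25, 0.75, 1.0].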
     @classmethod
-    def mkExprChoicesFn(cls, values, weights, seed_column, datatype):
-        """ Create SQL expression to compute the weighted values expression
-
-        build an expression of the form::
+    def mkExprChoicesFn(cls, values: list[Any], weights: list[float], seed_column: str, datatype: DataType) -> str:
+        """
+        Creates a SQL expression to compute a weighted values expression. Builds an expression of the form::
 
            case
              when rnd_column <= weight1 then value1
@@ -77,22 +95,22 @@ def mkExprChoicesFn(cls, values, weights, seed_column, datatype):
              else valueN
           end
 
-        based on computed probability distribution for values.
-
-        In Python 3.6 onwards, we could use the choices function but this python version is not
-        guaranteed on all Databricks distributions
+        The output expression is based on the computed probability distribution for the specified values.
 
-        :param values: list of values
-        :param weights: list of weights
-        :param seed_column: base column for expression
-        :param datatype: data type of function return value
+        In Python 3.6 onwards, we could use the choices function but this python version is not guaranteed on all
+        Databricks distributions.
 
+        :param values: List of values
+        :param weights: List of weights
+        :param seed_column: Base column name for expression
+        :param datatype: Spark `DataType` of the output expression
+        :returns: SQL expression representing the weighted values
         """
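        # Worked example (editor's illustration, not part of the module): for values
        # ['a', 'b'], weights [3, 1] and seed_column 'rnd', the CDF probabilities are
        # [0.75, 1.0], so the expression assembled below would take roughly the form
        #   CASE WHEN rnd <= 0.75 THEN 'a' ELSE 'b' END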
         cdf_probs = cls._mkCdfProbabilities(weights)
 
         output = [" CASE "]
 
-        conditions = zip(values, cdf_probs)
+        conditions = zip(values, cdf_probs, strict=False)
 
         for v, cdf in conditions:
             # TODO(alex): single quotes needs to be escaped