Skip to content

Commit c4d5182

Browse files
committed
Fix random unique value generation from a specified range
1 parent d97f39a commit c4d5182

File tree

2 files changed

+402
-18
lines changed

2 files changed

+402
-18
lines changed

dbldatagen/column_generation_spec.py

Lines changed: 172 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
import copy
1010
import logging
11+
import random
12+
from datetime import date, datetime, timedelta
1113

14+
import numpy as np
1215
from pyspark.sql.functions import col, pandas_udf
1316
from pyspark.sql.functions import lit, concat, rand, round as sql_round, array, expr, when, udf, \
1417
format_string
@@ -27,7 +30,7 @@
2730
from .nrange import NRange
2831
from .serialization import SerializableToDict
2932
from .text_generators import TemplateGenerator
30-
from .utils import ensure, coalesce_values
33+
from .utils import ensure, coalesce_values, parse_time_interval
3134
from .schema_parser import SchemaParser
3235

3336
HASH_COMPUTE_METHOD = "hash"
@@ -518,6 +521,130 @@ def _setupTemporaryColumns(self):
518521
'description': desc}))
519522
self._weightedBaseColumn = temp_name
520523

524+
def _list_random_unique_numeric_values(
525+
self,
526+
unique_count: int,
527+
min_val: int | float | date | datetime | str,
528+
max_val: int | float | date | datetime | str,
529+
step_val: int | float | timedelta | str
530+
) -> None:
531+
"""
532+
Builds a list of random unique numeric values when ``uniqueValues`` is specified and ``random=True``.
533+
534+
This creates an internal omitted column with a list of randomly selected unique values from the specified range,
535+
then sets up the main column to select from this list using a random index.
536+
537+
:param unique_count: Number of unique values to generate
538+
:param min_val: Minimum value of the range
539+
:param max_val: Maximum value of the range
540+
:param step_val: Step value for the range
541+
"""
542+
if self._randomSeed is not None and self._randomSeed != -1:
543+
self._set_random_seed()
544+
545+
selected_values = set()
546+
while len(selected_values) < unique_count:
547+
if self.distribution and isinstance(self.distribution, DataDistribution):
548+
raw_value = np.clip(self.distribution.generateNormalizedDistributionSample(), 0, 1)
549+
else:
550+
raw_value = random.random()
551+
552+
range_size = (max_val - min_val) / step_val
553+
if not isinstance(min_val, float) and not isinstance(max_val, float) and not isinstance(step_val, float):
554+
range_size = range_size + 1
555+
556+
scaled_index = int(raw_value * range_size)
557+
value = np.clip(min_val + scaled_index * step_val, min_val, max_val)
558+
selected_values.add(value)
559+
560+
selected_values = list(selected_values)
561+
if len(selected_values) < unique_count:
562+
self.logger.warning(
563+
f"Could not generate {unique_count} unique values for column {self.name}; "
564+
f"Generated {len(selected_values)} unique values"
565+
)
566+
567+
self.values = selected_values
568+
self.logger.info(
569+
f"Set up random unique values for column {self.name}: {len(selected_values)} values using "
570+
f"{'distribution' if self.distribution else 'uniform'} sampling"
571+
)
572+
573+
def _list_random_unique_datetime_values(
574+
self,
575+
unique_count: int,
576+
begin_val: date | datetime | str,
577+
end_val: date | datetime | str,
578+
interval_val: timedelta | str,
579+
col_type: DataType | str
580+
) -> None:
581+
"""
582+
Builds a list of random unique date/timestamp values when ``uniqueValues`` is specified and ``random=True``.
583+
584+
:param unique_count: Number of unique values to generate
585+
:param begin_val: Beginning date/timestamp
586+
:param end_val: End date/timestamp
587+
:param interval_val: Date/time interval
588+
:param col_type: Type of column to generate (e.g. ``DateType`` or ``TimestampType``)
589+
"""
590+
if isinstance(interval_val, str):
591+
interval_val = parse_time_interval(interval_val)
592+
if isinstance(begin_val, str):
593+
if isinstance(col_type, TimestampType):
594+
begin_val = datetime.strptime(begin_val, DateRange.DEFAULT_UTC_TS_FORMAT)
595+
else:
596+
begin_val = datetime.strptime(begin_val, DateRange.DEFAULT_DATE_FORMAT)
597+
if isinstance(end_val, str):
598+
if isinstance(col_type, TimestampType):
599+
end_val = datetime.strptime(end_val, DateRange.DEFAULT_UTC_TS_FORMAT)
600+
else:
601+
end_val = datetime.strptime(end_val, DateRange.DEFAULT_DATE_FORMAT)
602+
if isinstance(col_type, DateType):
603+
begin_val = begin_val.date()
604+
end_val = end_val.date()
605+
606+
total_span = end_val - begin_val
607+
if isinstance(total_span, timedelta):
608+
total_seconds = total_span.total_seconds()
609+
interval_seconds = interval_val.total_seconds()
610+
num_possible_values = int(total_seconds / interval_seconds) + 1
611+
else:
612+
total_days = total_span.days
613+
interval_days = interval_val.days
614+
num_possible_values = int(total_days / interval_days) + 1
615+
616+
unique_count = min(unique_count, num_possible_values)
617+
618+
if self._randomSeed is not None and self._randomSeed != -1:
619+
self._set_random_seed()
620+
621+
selected_values = set()
622+
while len(selected_values) < unique_count:
623+
if self.distribution and isinstance(self.distribution, DataDistribution):
624+
raw_value = np.clip(self.distribution.generateNormalizedDistributionSample(), 0, 1)
625+
else:
626+
raw_value = random.random()
627+
628+
scaled_index = int(raw_value * (num_possible_values - 1))
629+
value = begin_val + interval_val * scaled_index
630+
631+
if value > end_val:
632+
value = end_val
633+
selected_values.add(value)
634+
635+
selected_values = list(selected_values)
636+
if len(selected_values) < unique_count:
637+
self.logger.warning(
638+
f"Could not generate {unique_count} unique values for column {self.name}; "
639+
f"Generated {len(selected_values)} unique values"
640+
)
641+
642+
self.values = selected_values
643+
self.logger.info(
644+
f"Set up random unique values for column {self.name}: {len(selected_values)} values using "
645+
f"{'distribution' if self.distribution else 'uniform'} sampling"
646+
)
647+
521648
def _setup_logger(self):
522649
"""Set up logging
523650
@@ -553,12 +680,12 @@ def _computeAdjustedNumericRangeForColumn(self, colType, c_min, c_max, c_step, *
553680
- if a datarange is specified , use that range
554681
- if begin and end are specified or minValue and maxValue are specified, use that
555682
- if unique values is specified, compute minValue and maxValue depending on type
556-
683+
- if unique values and random=True are both specified, generate random unique values from full range
557684
"""
558685
if c_unique is not None:
559686
assert type(c_unique) is int, "unique_values must be integer"
560687
assert c_unique >= 1, "if supplied, unique values must be > 0"
561-
# TODO: set maxValue to unique_values + minValue & add unit test
688+
562689
effective_min, effective_max, effective_step = None, None, None
563690
if c_range is not None and type(c_range) is NRange:
564691
effective_min = c_range.minValue
@@ -568,19 +695,27 @@ def _computeAdjustedNumericRangeForColumn(self, colType, c_min, c_max, c_step, *
568695
effective_step = coalesce_values(effective_step, c_step, 1)
569696
effective_max = coalesce_values(effective_max, c_max)
570697

571-
# due to floating point errors in some Python floating point calculations, we need to apply rounding
572-
# if any of the components are float
573-
if type(effective_min) is float or type(effective_step) is float:
574-
unique_max = round(c_unique * effective_step + effective_min - effective_step, 9)
698+
# Check if both uniqueValues and random=True are specified
699+
if self.random and effective_max is not None:
700+
# Generate random unique values from the full range and store them
701+
self._list_random_unique_numeric_values(c_unique, effective_min, effective_max, effective_step)
702+
# Create a range that maps to indices of the unique values (0 to unique_count-1)
703+
result = NRange(0, c_unique - 1, 1)
575704
else:
576-
unique_max = c_unique * effective_step + effective_min - effective_step
577-
result = NRange(effective_min, unique_max, effective_step)
578-
579-
if result.maxValue is not None and effective_max is not None and result.maxValue > effective_max:
580-
self.logger.warning("Computed maxValue for column [%s] of %s is greater than specified maxValue %s",
581-
self.name,
582-
result.maxValue,
583-
effective_max)
705+
# Original behavior: create sequential range
706+
# due to floating point errors in some Python floating point calculations, we need to apply rounding
707+
# if any of the components are float
708+
if type(effective_min) is float or type(effective_step) is float:
709+
unique_max = round(c_unique * effective_step + effective_min - effective_step, 9)
710+
else:
711+
unique_max = c_unique * effective_step + effective_min - effective_step
712+
result = NRange(effective_min, unique_max, effective_step)
713+
714+
if result.maxValue is not None and effective_max is not None and result.maxValue > effective_max:
715+
self.logger.warning("Computed maxValue for column [%s] of %s is greater than specified maxValue %s",
716+
self.name,
717+
result.maxValue,
718+
effective_max)
584719
elif c_range is not None:
585720
result = c_range
586721
elif c_range is None:
@@ -607,10 +742,21 @@ def _computeAdjustedDateTimeRangeForColumn(self, colType, c_begin, c_end, c_inte
607742
effective_end = coalesce_values(effective_end, c_end)
608743
effective_begin = coalesce_values(effective_begin, c_begin)
609744

610-
if type(colType) is DateType:
611-
result = DateRange.computeDateRange(effective_begin, effective_end, effective_interval, c_unique)
745+
# Check if both uniqueValues and random=True are specified for date/timestamp
746+
if c_unique is not None and self.random and effective_end is not None:
747+
# Generate random unique date/timestamp values from the full range
748+
self._list_random_unique_datetime_values(c_unique, effective_begin, effective_end, effective_interval, colType)
749+
# Return a minimal range - the actual values will come from discrete values
750+
if type(colType) is DateType:
751+
result = DateRange.computeDateRange(effective_begin, effective_begin, effective_interval, 1)
752+
else:
753+
result = DateRange.computeTimestampRange(effective_begin, effective_begin, effective_interval, 1)
612754
else:
613-
result = DateRange.computeTimestampRange(effective_begin, effective_end, effective_interval, c_unique)
755+
# Original behavior
756+
if type(colType) is DateType:
757+
result = DateRange.computeDateRange(effective_begin, effective_end, effective_interval, c_unique)
758+
else:
759+
result = DateRange.computeTimestampRange(effective_begin, effective_end, effective_interval, c_unique)
614760

615761
self.logger.debug("Computing adjusted range for column: %s - %s", self.name, result)
616762
return result
@@ -1322,3 +1468,11 @@ def makeGenerationExpressions(self):
13221468
retval = F.slice(retval, F.lit(1), F.expr(expr_str))
13231469

13241470
return retval
1471+
1472+
def _set_random_seed(self) -> None:
1473+
"""
1474+
Sets the random seed value for computing random values from a range.
1475+
"""
1476+
seed_value = abs(self._randomSeed) % (2**32) # Numpy accepts values in the range from 0 - 2^32-1.
1477+
random.seed(seed_value)
1478+
np.random.seed(seed_value)

0 commit comments

Comments
 (0)