From 20db964fca51d8e170cf4009433e09fd9e3b3397 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Tue, 23 Sep 2025 12:16:41 -0400 Subject: [PATCH 01/20] working test wireup --- dbldatagen/spec/compat.py | 18 ++++++++++ dbldatagen/spec/generator.py | 69 ++++++++++++++++++++++++++++++++++++ makefile | 4 +-- pyproject.toml | 29 ++++++++++----- tests/test_spec.py | 7 ++++ 5 files changed, 117 insertions(+), 10 deletions(-) create mode 100644 dbldatagen/spec/compat.py create mode 100644 dbldatagen/spec/generator.py create mode 100644 tests/test_spec.py diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py new file mode 100644 index 00000000..68649b36 --- /dev/null +++ b/dbldatagen/spec/compat.py @@ -0,0 +1,18 @@ +# This module acts as a compatibility layer for Pydantic V1 and V2. + +try: + # This will succeed on environments with Pydantic V2.x + # It imports the V1 API that is bundled within V2. + from pydantic.v1 import BaseModel, Field, validator, constr + +except ImportError: + # This will be executed on environments with only Pydantic V1.x + from pydantic import BaseModel, Field, validator, constr + +# In your application code, do this: +# from .compat import BaseModel +# NOT this: +# from pydantic import BaseModel + +# FastAPI Notes +# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py \ No newline at end of file diff --git a/dbldatagen/spec/generator.py b/dbldatagen/spec/generator.py new file mode 100644 index 00000000..6039165f --- /dev/null +++ b/dbldatagen/spec/generator.py @@ -0,0 +1,69 @@ +from .compat import BaseModel +from typing import Dict, Optional, Union, Any + + +# class ColumnDefinition(BaseModel): +# name: str +# type: Optional[DbldatagenBasicType] = None +# primary: bool = False +# options: Optional[Dict[str, Any]] = {} +# nullable: Optional[bool] = False +# omit: Optional[bool] = False +# baseColumn: Optional[str] = "id" +# baseColumnType: Optional[str] = "auto" + +# @model_validator(mode="after") +# def check_constraints(self): +# if self.primary: +# if "min" in self.options or "max" in self.options: +# raise ValueError( +# f"Primary column '{self.name}' cannot have min/max options.") +# if self.nullable: +# raise ValueError( +# f"Primary column '{self.name}' cannot be nullable.") +# if self.primary and self.type is None: +# raise ValueError( +# f"Primary column '{self.name}' must have a type defined.") +# return self + + +# class TableDefinition(BaseModel): +# number_of_rows: int +# partitions: Optional[int] = None +# columns: List[ColumnDefinition] + + +# class DatagenSpec(BaseModel): +# tables: Dict[str, TableDefinition] +# output_destination: Optional[Union[UCSchemaTarget, FilePathTarget]] = None +# generator_options: Optional[Dict[str, Any]] = {} +# intended_for_databricks: Optional[bool] = None + + + +# def display_all_tables(self): +# for table_name, table_def in self.tables.items(): +# print(f"Table: {table_name}") + +# if self.output_destination: +# output = f"{self.output_destination}" +# display(HTML(f"Output destination: {output}")) +# else: +# message = ( +# "Output destination: " +# "None
" +# "Set it using the output_destination " +# "attribute on your DatagenSpec object " +# "(e.g., my_spec.output_destination = UCSchemaTarget(...))." +# ) +# display(HTML(message)) + +# df = pd.DataFrame([col.dict() for col in table_def.columns]) +# try: +# display(df) +# except NameError: +# print(df.to_string()) + + +class DatagenSpec(BaseModel): + name: str \ No newline at end of file diff --git a/makefile b/makefile index 772397bf..a5f4486c 100644 --- a/makefile +++ b/makefile @@ -8,7 +8,7 @@ clean: .venv/bin/python: pip install hatch - hatch env create + hatch env create test-pydantic.pydantic==1.10.6-v1 dev: .venv/bin/python @hatch run which python @@ -20,7 +20,7 @@ fmt: hatch run fmt test: - hatch run test + hatch run test-pydantic:test test-coverage: make test && open htmlcov/index.html diff --git a/pyproject.toml b/pyproject.toml index 13728ba2..304562e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,18 +106,31 @@ dependencies = [ ] python="3.10" - -# store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better path = ".venv" [tool.hatch.envs.default.scripts] test = "pytest tests/ -n 10 --cov --cov-report=html --timeout 600 --durations 20" -fmt = ["ruff check . --fix", - "mypy .", - "pylint --output-format=colorized -j 0 dbldatagen tests"] -verify = ["ruff check .", - "mypy .", - "pylint --output-format=colorized -j 0 dbldatagen tests"] +fmt = [ + "ruff check . --fix", + "mypy .", + "pylint --output-format=colorized -j 0 dbldatagen tests" +] +verify = [ + "ruff check .", + "mypy .", + "pylint --output-format=colorized -j 0 dbldatagen tests" +] + + +[tool.hatch.envs.test-pydantic] +template = "default" +matrix = [ + { pydantic_version = ["1.10.6", "2.8.2"] } +] +extra-dependencies = [ + "pydantic=={matrix:pydantic_version}" +] + # Ruff configuration - replaces flake8, isort, pydocstyle, etc. 
[tool.ruff] diff --git a/tests/test_spec.py b/tests/test_spec.py new file mode 100644 index 00000000..2e12cb27 --- /dev/null +++ b/tests/test_spec.py @@ -0,0 +1,7 @@ +from dbldatagen.spec.generator import DatagenSpec + +def test_spec(): + spec = DatagenSpec(name="test_spec") + assert spec.name == "test_spec" + + From 4e9cba5272586692337ea2f29574a428ff3e585d Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Tue, 4 Nov 2025 09:10:38 -0500 Subject: [PATCH 02/20] Initial code, spec and test, pushing for review --- dbldatagen/spec/compat.py | 17 +- dbldatagen/spec/generator.py | 69 ---- dbldatagen/spec/generator_spec.py | 324 +++++++++++++++++ dbldatagen/spec/generator_spec_impl.py | 254 ++++++++++++++ pydantic_compat.md | 101 ++++++ scratch.md | 4 + tests/test_spec.py | 7 - tests/test_specs.py | 466 +++++++++++++++++++++++++ 8 files changed, 1164 insertions(+), 78 deletions(-) delete mode 100644 dbldatagen/spec/generator.py create mode 100644 dbldatagen/spec/generator_spec.py create mode 100644 dbldatagen/spec/generator_spec_impl.py create mode 100644 pydantic_compat.md create mode 100644 scratch.md delete mode 100644 tests/test_spec.py create mode 100644 tests/test_specs.py diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py index 68649b36..7c30d57d 100644 --- a/dbldatagen/spec/compat.py +++ b/dbldatagen/spec/compat.py @@ -7,7 +7,7 @@ except ImportError: # This will be executed on environments with only Pydantic V1.x - from pydantic import BaseModel, Field, validator, constr + from pydantic import BaseModel, Field, validator, constr, root_validator, field_validator # In your application code, do this: # from .compat import BaseModel @@ -15,4 +15,17 @@ # from pydantic import BaseModel # FastAPI Notes -# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py \ No newline at end of file +# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py + + +""" +## Why This Approach +No Installation Required: It directly addresses your core requirement. +You don't need to %pip install anything, which avoids conflicts with the pre-installed libraries on Databricks. +Single Codebase: You maintain one set of code that is guaranteed to work with the Pydantic V1 API, which is available in both runtimes. + +Environment Agnostic: Your application code in models.py has no idea which version of Pydantic is actually installed. The compat.py module handles that complexity completely. + +Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features), +you only need to change your application code and your compat.py import statements, making the transition much clearer. 
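+
+Illustrative usage (a minimal sketch; the model below is hypothetical and only shows the
+import pattern, it is not part of this package):
+
+    # application code imports the shim, never pydantic directly
+    from .compat import BaseModel, Field, validator
+
+    class ExampleSpec(BaseModel):
+        name: str = Field(..., min_length=1)
+
+        @validator("name")
+        def strip_name(cls, v):
+            return v.strip()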
+""" \ No newline at end of file diff --git a/dbldatagen/spec/generator.py b/dbldatagen/spec/generator.py deleted file mode 100644 index 6039165f..00000000 --- a/dbldatagen/spec/generator.py +++ /dev/null @@ -1,69 +0,0 @@ -from .compat import BaseModel -from typing import Dict, Optional, Union, Any - - -# class ColumnDefinition(BaseModel): -# name: str -# type: Optional[DbldatagenBasicType] = None -# primary: bool = False -# options: Optional[Dict[str, Any]] = {} -# nullable: Optional[bool] = False -# omit: Optional[bool] = False -# baseColumn: Optional[str] = "id" -# baseColumnType: Optional[str] = "auto" - -# @model_validator(mode="after") -# def check_constraints(self): -# if self.primary: -# if "min" in self.options or "max" in self.options: -# raise ValueError( -# f"Primary column '{self.name}' cannot have min/max options.") -# if self.nullable: -# raise ValueError( -# f"Primary column '{self.name}' cannot be nullable.") -# if self.primary and self.type is None: -# raise ValueError( -# f"Primary column '{self.name}' must have a type defined.") -# return self - - -# class TableDefinition(BaseModel): -# number_of_rows: int -# partitions: Optional[int] = None -# columns: List[ColumnDefinition] - - -# class DatagenSpec(BaseModel): -# tables: Dict[str, TableDefinition] -# output_destination: Optional[Union[UCSchemaTarget, FilePathTarget]] = None -# generator_options: Optional[Dict[str, Any]] = {} -# intended_for_databricks: Optional[bool] = None - - - -# def display_all_tables(self): -# for table_name, table_def in self.tables.items(): -# print(f"Table: {table_name}") - -# if self.output_destination: -# output = f"{self.output_destination}" -# display(HTML(f"Output destination: {output}")) -# else: -# message = ( -# "Output destination: " -# "None
" -# "Set it using the output_destination " -# "attribute on your DatagenSpec object " -# "(e.g., my_spec.output_destination = UCSchemaTarget(...))." -# ) -# display(HTML(message)) - -# df = pd.DataFrame([col.dict() for col in table_def.columns]) -# try: -# display(df) -# except NameError: -# print(df.to_string()) - - -class DatagenSpec(BaseModel): - name: str \ No newline at end of file diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py new file mode 100644 index 00000000..23afc4a0 --- /dev/null +++ b/dbldatagen/spec/generator_spec.py @@ -0,0 +1,324 @@ +from .compat import BaseModel, validator, root_validator, field_validator +from typing import Dict, Optional, Union, Any, Literal, List +import pandas as pd +from IPython.display import display, HTML + +DbldatagenBasicType = Literal[ + "string", + "int", + "long", + "float", + "double", + "decimal", + "boolean", + "date", + "timestamp", + "short", + "byte", + "binary", + "integer", + "bigint", + "tinyint", +] + +class ColumnDefinition(BaseModel): + name: str + type: Optional[DbldatagenBasicType] = None + primary: bool = False + options: Optional[Dict[str, Any]] = {} + nullable: Optional[bool] = False + omit: Optional[bool] = False + baseColumn: Optional[str] = "id" + baseColumnType: Optional[str] = "auto" + + @root_validator(skip_on_failure=True) + def check_model_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + Validates constraints across the entire model after individual fields are processed. + """ + is_primary = values.get("primary") + options = values.get("options", {}) + name = values.get("name") + is_nullable = values.get("nullable") + column_type = values.get("type") + + if is_primary: + if "min" in options or "max" in options: + raise ValueError(f"Primary column '{name}' cannot have min/max options.") + + if is_nullable: + raise ValueError(f"Primary column '{name}' cannot be nullable.") + + if column_type is None: + raise ValueError(f"Primary column '{name}' must have a type defined.") + return values + + +class UCSchemaTarget(BaseModel): + catalog: str + schema_: str + output_format: str = "delta" # Default to delta for UC Schema + + @field_validator("catalog", "schema_", mode="after") + def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argument + if not v.strip(): + raise ValueError("Identifier must be non-empty.") + if not v.isidentifier(): + logger.warning( + f"'{v}' is not a basic Python identifier. 
Ensure validity for Unity Catalog.") + return v.strip() + + def __str__(self): + return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" + + +class FilePathTarget(BaseModel): + base_path: str + output_format: Literal["csv", "parquet"] # No default, must be specified + + @field_validator("base_path", mode="after") + def validate_base_path(cls, v): # noqa: N805, pylint: disable=no-self-argument + if not v.strip(): + raise ValueError("base_path must be non-empty.") + return v.strip() + + def __str__(self): + return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" + + +class TableDefinition(BaseModel): + number_of_rows: int + partitions: Optional[int] = None + columns: List[ColumnDefinition] + + +class ValidationResult: + """Container for validation results with errors and warnings.""" + + def __init__(self) -> None: + self.errors: List[str] = [] + self.warnings: List[str] = [] + + def add_error(self, message: str) -> None: + """Add an error message.""" + self.errors.append(message) + + def add_warning(self, message: str) -> None: + """Add a warning message.""" + self.warnings.append(message) + + def is_valid(self) -> bool: + """Returns True if there are no errors.""" + return len(self.errors) == 0 + + def __str__(self) -> str: + """String representation of validation results.""" + lines = [] + if self.is_valid(): + lines.append("✓ Validation passed successfully") + else: + lines.append("✗ Validation failed") + + if self.errors: + lines.append(f"\nErrors ({len(self.errors)}):") + for i, error in enumerate(self.errors, 1): + lines.append(f" {i}. {error}") + + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + for i, warning in enumerate(self.warnings, 1): + lines.append(f" {i}. {warning}") + + return "\n".join(lines) + +class DatagenSpec(BaseModel): + tables: Dict[str, TableDefinition] + output_destination: Optional[Union[UCSchemaTarget, FilePathTarget]] = None # there is a abstraction, may be we can use that? talk to Greg + generator_options: Optional[Dict[str, Any]] = {} + intended_for_databricks: Optional[bool] = None # May be infered. + + def _check_circular_dependencies( + self, + table_name: str, + columns: List[ColumnDefinition] + ) -> List[str]: + """ + Check for circular dependencies in baseColumn references. + Returns a list of error messages if circular dependencies are found. + """ + errors = [] + column_map = {col.name: col for col in columns} + + for col in columns: + if col.baseColumn and col.baseColumn != "id": + # Track the dependency chain + visited = set() + current = col.name + + while current: + if current in visited: + # Found a cycle + cycle_path = " -> ".join(list(visited) + [current]) + errors.append( + f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}" + ) + break + + visited.add(current) + current_col = column_map.get(current) + + if not current_col: + break + + # Move to the next column in the chain + if current_col.baseColumn and current_col.baseColumn != "id": + if current_col.baseColumn not in column_map: + # baseColumn doesn't exist - we'll catch this in another validation + break + current = current_col.baseColumn + else: + # Reached a column that doesn't have a baseColumn or uses "id" + break + + return errors + + def validate(self, strict: bool = True) -> ValidationResult: + """ + Validates the entire DatagenSpec configuration. + Always runs all validation checks and collects all errors and warnings. 
+ + Args: + strict: If True, raises ValueError if any errors or warnings are found. + If False, only raises ValueError if errors (not warnings) are found. + + Returns: + ValidationResult object containing all errors and warnings found. + + Raises: + ValueError: If validation fails based on strict mode setting. + The exception message contains all errors and warnings. + """ + result = ValidationResult() + + # 1. Check that there's at least one table + if not self.tables: + result.add_error("Spec must contain at least one table definition") + + # 2. Validate each table (continue checking all tables even if errors found) + for table_name, table_def in self.tables.items(): + # Check table has at least one column + if not table_def.columns: + result.add_error(f"Table '{table_name}' must have at least one column") + continue # Skip further checks for this table since it has no columns + + # Check row count is positive + if table_def.number_of_rows <= 0: + result.add_error( + f"Table '{table_name}' has invalid number_of_rows: {table_def.number_of_rows}. " + "Must be a positive integer." + ) + + # Check partitions if specified + #TODO: though this can be a model field check, we are checking here so that one can correct + # Can we find a way to use the default way? + if table_def.partitions is not None and table_def.partitions <= 0: + result.add_error( + f"Table '{table_name}' has invalid partitions: {table_def.partitions}. " + "Must be a positive integer or None." + ) + + # Check for duplicate column names + # TODO: Not something possible if we right model, recheck + column_names = [col.name for col in table_def.columns] + duplicates = [name for name in set(column_names) if column_names.count(name) > 1] + if duplicates: + result.add_error( + f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}" + ) + + # Build column map for reference checking + column_map = {col.name: col for col in table_def.columns} + + # TODO: Check baseColumn references, this is tricky? check the dbldefaults + for col in table_def.columns: + if col.baseColumn and col.baseColumn != "id": + if col.baseColumn not in column_map: + result.add_error( + f"Table '{table_name}', column '{col.name}': " + f"baseColumn '{col.baseColumn}' does not exist in the table" + ) + + # Check for circular dependencies in baseColumn references + circular_errors = self._check_circular_dependencies(table_name, table_def.columns) + for error in circular_errors: + result.add_error(error) + + # Check primary key constraints + primary_columns = [col for col in table_def.columns if col.primary] + if len(primary_columns) > 1: + primary_names = [col.name for col in primary_columns] + result.add_warning( + f"Table '{table_name}' has multiple primary columns: {', '.join(primary_names)}. " + "This may not be the intended behavior." + ) + + # Check for columns with no type and not using baseColumn properly + for col in table_def.columns: + if not col.primary and not col.type and not col.options: + result.add_warning( + f"Table '{table_name}', column '{col.name}': " + "No type specified and no options provided. " + "Column may not generate data as expected." + ) + + # 3. Check output destination + if not self.output_destination: + result.add_warning( + "No output_destination specified. Data will be generated but not persisted. " + "Set output_destination to save generated data." + ) + + # 4. 
Validate generator options (if any known options) + if self.generator_options: + known_options = [ + "random", "randomSeed", "randomSeedMethod", "verbose", + "debug", "seedColumnName" + ] + for key in self.generator_options.keys(): + if key not in known_options: + result.add_warning( + f"Unknown generator option: '{key}'. " + "This may be ignored during generation." + ) + + # Now that all validations are complete, decide whether to raise + if strict and (result.errors or result.warnings): + raise ValueError(str(result)) + elif not strict and result.errors: + raise ValueError(str(result)) + + return result + + + def display_all_tables(self) -> None: + for table_name, table_def in self.tables.items(): + print(f"Table: {table_name}") + + if self.output_destination: + output = f"{self.output_destination}" + display(HTML(f"Output destination: {output}")) + else: + message = ( + "Output destination: " + "None
" + "Set it using the output_destination " + "attribute on your DatagenSpec object " + "(e.g., my_spec.output_destination = UCSchemaTarget(...))." + ) + display(HTML(message)) + + df = pd.DataFrame([col.dict() for col in table_def.columns]) + try: + display(df) + except NameError: + print(df.to_string()) diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py new file mode 100644 index 00000000..a508b1a5 --- /dev/null +++ b/dbldatagen/spec/generator_spec_impl.py @@ -0,0 +1,254 @@ +import logging +from typing import Dict, Union +import posixpath + +from dbldatagen.spec.generator_spec import TableDefinition +from pyspark.sql import SparkSession +import dbldatagen as dg +from .generator_spec import DatagenSpec, UCSchemaTarget, FilePathTarget, ColumnDefinition + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + +INTERNAL_ID_COLUMN_NAME = "id" + + +class Generator: + """ + Main data generation orchestrator that handles configuration, preparation, and writing of data. + """ + + def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> None: + """ + Initialize the Generator with a SparkSession. + Args: + spark: An existing SparkSession instance + app_name: Application name for logging purposes + Raises: + RuntimeError: If spark is None + """ + if not spark: + logger.error( + "SparkSession cannot be None during Generator initialization") + raise RuntimeError("SparkSession cannot be None") + self.spark = spark + self._created_spark_session = False + self.app_name = app_name + logger.info("Generator initialized with SparkSession") + + def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> Dict[str, str]: + """ + Convert a ColumnDefinition to dbldatagen column specification. + Args: + col_def: ColumnDefinition object containing column configuration + Returns: + Dictionary containing dbldatagen column specification + """ + col_name = col_def.name + col_type = col_def.type + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_def.primary: + kwargs["colType"] = col_type + kwargs["baseColumn"] = INTERNAL_ID_COLUMN_NAME + + if col_type == "string": + kwargs["baseColumnType"] = "hash" + elif col_type not in ["int", "long", "integer", "bigint", "short"]: + kwargs["baseColumnType"] = "auto" + logger.warning( + f"Primary key '{col_name}' has non-standard type '{col_type}'") + + # Log conflicting options for primary keys + conflicting_opts_for_pk = [ + "distribution", "template", "dataRange", "random", "omit", + "min", "max", "uniqueValues", "values", "expr" + ] + + for opt_key in conflicting_opts_for_pk: + if opt_key in kwargs: + logger.warning( + f"Primary key '{col_name}': Option '{opt_key}' may be ignored") + + if col_def.omit is not None and col_def.omit: + kwargs["omit"] = True + else: + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_type: + kwargs["colType"] = col_type + if col_def.baseColumn: + kwargs["baseColumn"] = col_def.baseColumn + if col_def.baseColumnType: + kwargs["baseColumnType"] = col_def.baseColumnType + if col_def.omit is not None: + kwargs["omit"] = col_def.omit + + return kwargs + + def _prepare_data_generators( + self, + config: DatagenSpec, + config_source_name: str = "PydanticConfig" + ) -> Dict[str, dg.DataGenerator]: + """ + Prepare DataGenerator specifications for each table based on the configuration. 
+ Args: + config: DatagenSpec Pydantic object containing table configurations + config_source_name: Name for the configuration source (for logging) + Returns: + Dictionary mapping table names to their configured dbldatagen.DataGenerator objects + Raises: + RuntimeError: If SparkSession is not available + ValueError: If any table preparation fails + Exception: If any unexpected error occurs during preparation + """ + logger.info( + f"Preparing data generators for {len(config.tables)} tables") + + if not self.spark: + logger.error( + "SparkSession is not available. Cannot prepare data generators") + raise RuntimeError( + "SparkSession is not available. Cannot prepare data generators") + + tables_config: Dict[str, TableDefinition] = config.tables + global_gen_options = config.generator_options if config.generator_options else {} + + prepared_generators: Dict[str, dg.DataGenerator] = {} + generation_order = list(tables_config.keys()) # This becomes impotant when we get into multitable + + for table_name in generation_order: + table_spec = tables_config[table_name] + logger.info(f"Preparing table: {table_name}") + + try: + # Create DataGenerator instance + data_gen = dg.DataGenerator( + sparkSession=self.spark, + name=f"{table_name}_spec_from_{config_source_name}", + rows=table_spec.number_of_rows, + partitions=table_spec.partitions, + **global_gen_options, + ) + + # Process each column + for col_def in table_spec.columns: + kwargs = self._columnspec_to_datagen_columnspec(col_def) + data_gen = data_gen.withColumn(colName=col_def.name, **kwargs) + # Has performance implications. + + prepared_generators[table_name] = data_gen + logger.info(f"Successfully prepared table: {table_name}") + + except Exception as e: + logger.error(f"Failed to prepare table '{table_name}': {e}") + raise RuntimeError( + f"Failed to prepare table '{table_name}': {e}") from e + + logger.info("All data generators prepared successfully") + return prepared_generators + + def write_prepared_data( + self, + prepared_generators: Dict[str, dg.DataGenerator], + output_destination: Union[UCSchemaTarget, FilePathTarget, None], + config_source_name: str = "PydanticConfig", + ) -> None: + """ + Write data from prepared generators to the specified output destination. 
+ + Args: + prepared_generators: Dictionary of prepared DataGenerator objects + output_destination: Target destination for data output + config_source_name: Name for the configuration source (for logging) + + Raises: + RuntimeError: If any table write fails + ValueError: If output destination is not properly configured + """ + logger.info("Starting data writing phase") + + if not prepared_generators: + logger.warning("No prepared data generators to write") + return + + for table_name, data_gen in prepared_generators.items(): + logger.info(f"Writing table: {table_name}") + + try: + df = data_gen.build() + requested_rows = data_gen.rowCount + actual_row_count = df.count() + logger.info( + f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})") + + if actual_row_count == 0 and requested_rows > 0: + logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0") + + # Write data based on destination type + if isinstance(output_destination, FilePathTarget): + output_path = posixpath.join(output_destination.base_path, table_name) + df.write.format(output_destination.output_format).mode("overwrite").save(output_path) + logger.info(f"Wrote table '{table_name}' to file path: {output_path}") + + elif isinstance(output_destination, UCSchemaTarget): + output_table = f"{output_destination.catalog}.{output_destination.schema_}.{table_name}" + df.write.mode("overwrite").saveAsTable(output_table) + logger.info(f"Wrote table '{table_name}' to Unity Catalog: {output_table}") + else: + logger.warning("No output destination specified, skipping data write") + return + except Exception as e: + logger.error(f"Failed to write table '{table_name}': {e}") + raise RuntimeError(f"Failed to write table '{table_name}': {e}") from e + logger.info("All data writes completed successfully") + + def generate_and_write_data( + self, + config: DatagenSpec, + config_source_name: str = "PydanticConfig" + ) -> None: + """ + Combined method to prepare data generators and write data in one operation. + This method orchestrates the complete data generation workflow: + 1. Prepare data generators from configuration + 2. 
Write data to the specified destination + Args: + config: DatagenSpec Pydantic object containing table configurations + config_source_name: Name for the configuration source (for logging) + Raises: + RuntimeError: If SparkSession is not available or any step fails + ValueError: If critical errors occur during preparation or writing + """ + logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables") + + try: + # Phase 1: Prepare data generators + prepared_generators_map = self._prepare_data_generators(config, config_source_name) + + if not prepared_generators_map and list(config.tables.keys()): + logger.warning( + "No data generators were successfully prepared, though tables were defined") + return + + # Phase 2: Write data + self.write_prepared_data( + prepared_generators_map, + config.output_destination, + config_source_name + ) + + logger.info( + "Combined data generation and writing completed successfully") + + except Exception as e: + logger.error( + f"Error during combined data generation and writing: {e}") + raise RuntimeError( + f"Error during combined data generation and writing: {e}") from e \ No newline at end of file diff --git a/pydantic_compat.md b/pydantic_compat.md new file mode 100644 index 00000000..abf26e60 --- /dev/null +++ b/pydantic_compat.md @@ -0,0 +1,101 @@ +To write code that works on both Pydantic V1 and V2 and ensures a smooth future migration, you should code against the V1 API but import it through a compatibility shim. This approach uses V1's syntax, which Pydantic V2 can understand via its built-in V1 compatibility layer. + +----- + +### \#\# The Golden Rule: Code to V1, Import via a Shim 💡 + +The core strategy is to **write all your models using Pydantic V1 syntax and features**. You then use a special utility file to handle the imports, which makes your application code completely agnostic to the installed Pydantic version. + +----- + +### \#\# 1. Implement a Compatibility Shim (`compat.py`) + +This is the most critical step. Create a file named `compat.py` in your project that intelligently imports Pydantic components. Your application will import everything from this file instead of directly from `pydantic`. + +```python +# compat.py +# This module acts as a compatibility layer for Pydantic V1 and V2. + +try: + # This will succeed on environments with Pydantic V2.x + # It imports the V1 API that is bundled within V2. + from pydantic.v1 import BaseModel, Field, validator, constr + +except ImportError: + # This will be executed on environments with only Pydantic V1.x + from pydantic import BaseModel, Field, validator, constr + +# In your application code, do this: +# from .compat import BaseModel +# NOT this: +# from pydantic import BaseModel +``` + +----- + +### \#\# 2. Stick to V1 Features and Syntax (Do's and Don'ts) + +By following these rules in your application code, you ensure the logic works on both versions. + +#### **✅ Models and Fields: DO** + + * Use standard `BaseModel` and `Field` for all your data structures. This is the most stable part of the API. + +#### **❌ Models and Fields: DON'T** + + * **Do not use `__root__` models**. This V1 feature was removed in V2 and the compatibility is not perfect. Instead, model the data explicitly, even if it feels redundant. + * **Bad (Avoid):** `class MyList(BaseModel): __root__: list[str]` + * **Good (Compatible):** `class MyList(BaseModel): items: list[str]` + +#### **✅ Configuration: DO** + + * Use the nested `class Config:` for model configuration. 
This is the V1 way and is fully supported by the V2 compatibility layer. + * **Example:** + ```python + from .compat import BaseModel + + class User(BaseModel): + id: int + full_name: str + + class Config: + orm_mode = True # V2's compatibility layer translates this + allow_population_by_field_name = True + ``` + +#### **❌ Configuration: DON'T** + + * **Do not use the V2 `model_config` dictionary**. This is a V2-only feature. + +#### **✅ Validators and Data Types: DO** + + * Use the standard V1 `@validator`. It's robust and works perfectly across both versions. + * Use V1 constrained types like `constr`, `conint`, `conlist`. + * **Example:** + ```python + from .compat import BaseModel, validator, constr + + class Product(BaseModel): + name: constr(min_length=3) + + @validator("name") + def name_must_be_alpha(cls, v): + if not v.isalpha(): + raise ValueError("Name must be alphabetic") + return v + ``` + +#### **❌ Validators and Data Types: DON'T** + + * **Do not use V2 decorators** like `@field_validator`, `@model_validator`, or `@field_serializer`. + * **Do not use the V2 `Annotated` syntax** for validation (e.g., `Annotated[str, StringConstraints(min_length=2)]`). + +----- + +### \#\# 3. The Easy Migration Path + +When you're finally ready to leave V1 behind and upgrade your code to be V2-native, the process will be straightforward because your code is already consistent: + +1. **Change Imports**: Your first step will be a simple find-and-replace to change all `from .compat import ...` statements to `from pydantic import ...`. +2. **Run a Codelinter**: Tools like **Ruff** have built-in rules that can automatically refactor most of your V1 syntax (like `Config` classes and `@validator`s) to the new V2 syntax. +3. **Manual Refinements**: Address any complex patterns the automated tools couldn't handle, like replacing your `__root__` model alternatives. 
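+
+For orientation, here is a rough sketch of what the earlier `Product` example (together with
+the `Config` settings shown for `User`) could look like once rewritten in native V2 syntax.
+It is illustrative only, but it is the kind of result that steps 1-3 above should produce:
+
+```python
+from typing import Annotated
+
+from pydantic import BaseModel, ConfigDict, StringConstraints, field_validator
+
+
+class Product(BaseModel):
+    # V1's nested `class Config: allow_population_by_field_name = True` becomes model_config
+    model_config = ConfigDict(populate_by_name=True)
+
+    # V1's constr(min_length=3) becomes the Annotated form
+    name: Annotated[str, StringConstraints(min_length=3)]
+
+    # V1's @validator becomes @field_validator (paired with @classmethod)
+    @field_validator("name")
+    @classmethod
+    def name_must_be_alpha(cls, v: str) -> str:
+        if not v.isalpha():
+            raise ValueError("Name must be alphabetic")
+        return v
+```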
\ No newline at end of file diff --git a/scratch.md b/scratch.md new file mode 100644 index 00000000..a3afa5c3 --- /dev/null +++ b/scratch.md @@ -0,0 +1,4 @@ +Pydantic Notes +https://docs.databricks.com/aws/en/release-notes/runtime/14.3lts - 1.10.6 +https://docs.databricks.com/aws/en/release-notes/runtime/15.4lts - 1.10.6 +https://docs.databricks.com/aws/en/release-notes/runtime/16.4lts - 2.8.2 (2.20.1 - core) \ No newline at end of file diff --git a/tests/test_spec.py b/tests/test_spec.py deleted file mode 100644 index 2e12cb27..00000000 --- a/tests/test_spec.py +++ /dev/null @@ -1,7 +0,0 @@ -from dbldatagen.spec.generator import DatagenSpec - -def test_spec(): - spec = DatagenSpec(name="test_spec") - assert spec.name == "test_spec" - - diff --git a/tests/test_specs.py b/tests/test_specs.py new file mode 100644 index 00000000..d3c8ab2c --- /dev/null +++ b/tests/test_specs.py @@ -0,0 +1,466 @@ +from dbldatagen.spec.generator_spec import DatagenSpec +import pytest +from dbldatagen.spec.generator_spec import ( + DatagenSpec, + TableDefinition, + ColumnDefinition, + UCSchemaTarget, + FilePathTarget, + ValidationResult +) + +class TestValidationResult: + """Tests for ValidationResult class""" + + def test_empty_result_is_valid(self): + result = ValidationResult() + assert result.is_valid() + assert len(result.errors) == 0 + assert len(result.warnings) == 0 + + def test_result_with_errors_is_invalid(self): + result = ValidationResult() + result.add_error("Test error") + assert not result.is_valid() + assert len(result.errors) == 1 + + def test_result_with_only_warnings_is_valid(self): + result = ValidationResult() + result.add_warning("Test warning") + assert result.is_valid() + assert len(result.warnings) == 1 + + def test_result_string_representation(self): + result = ValidationResult() + result.add_error("Error 1") + result.add_error("Error 2") + result.add_warning("Warning 1") + + result_str = str(result) + assert "✗ Validation failed" in result_str + assert "Errors (2)" in result_str + assert "Error 1" in result_str + assert "Error 2" in result_str + assert "Warnings (1)" in result_str + assert "Warning 1" in result_str + + def test_valid_result_string_representation(self): + result = ValidationResult() + result_str = str(result) + assert "✓ Validation passed successfully" in result_str + + +class TestColumnDefinitionValidation: + """Tests for ColumnDefinition validation""" + + def test_valid_primary_column(self): + col = ColumnDefinition( + name="id", + type="int", + primary=True + ) + assert col.primary + assert col.type == "int" + + def test_primary_column_with_min_max_raises_error(self): + with pytest.raises(ValueError, match="cannot have min/max options"): + ColumnDefinition( + name="id", + type="int", + primary=True, + options={"min": 1, "max": 100} + ) + + def test_primary_column_nullable_raises_error(self): + with pytest.raises(ValueError, match="cannot be nullable"): + ColumnDefinition( + name="id", + type="int", + primary=True, + nullable=True + ) + + def test_primary_column_without_type_raises_error(self): + with pytest.raises(ValueError, match="must have a type defined"): + ColumnDefinition( + name="id", + primary=True + ) + + def test_non_primary_column_without_type(self): + # Should not raise + col = ColumnDefinition( + name="data", + options={"values": ["a", "b", "c"]} + ) + assert col.name == "data" + + +class TestDatagenSpecValidation: + """Tests for DatagenSpec.validate() method""" + + def test_valid_spec_passes_validation(self): + spec = DatagenSpec( + tables={ + 
"users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="name", type="string", options={"values": ["Alice", "Bob"]}), + ] + ) + }, + output_destination=UCSchemaTarget(catalog="main", schema_="default") + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.errors) == 0 + + def test_empty_tables_raises_error(self): + spec = DatagenSpec(tables={}) + + with pytest.raises(ValueError, match="at least one table"): + spec.validate(strict=True) + + def test_table_without_columns_raises_error(self): + spec = DatagenSpec( + tables={ + "empty_table": TableDefinition( + number_of_rows=100, + columns=[] + ) + } + ) + + with pytest.raises(ValueError, match="must have at least one column"): + spec.validate() + + def test_negative_row_count_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=-10, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid number_of_rows"): + spec.validate() + + def test_zero_row_count_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=0, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid number_of_rows"): + spec.validate() + + def test_invalid_partitions_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + partitions=-5, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + with pytest.raises(ValueError, match="invalid partitions"): + spec.validate() + + def test_duplicate_column_names_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="duplicate", type="string"), + ColumnDefinition(name="duplicate", type="int"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="duplicate column names"): + spec.validate() + + def test_invalid_base_column_reference_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="email", type="string", baseColumn="nonexistent"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="does not exist"): + spec.validate() + + def test_circular_dependency_raises_error(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="col_a", type="string", baseColumn="col_b"), + ColumnDefinition(name="col_b", type="string", baseColumn="col_c"), + ColumnDefinition(name="col_c", type="string", baseColumn="col_a"), + ] + ) + } + ) + + with pytest.raises(ValueError, match="Circular dependency"): + spec.validate() + + def test_multiple_primary_columns_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id1", type="int", primary=True), + ColumnDefinition(name="id2", type="int", primary=True), + ] + ) + } + ) + + # In strict mode, warnings cause errors + with pytest.raises(ValueError, match="multiple primary columns"): + spec.validate(strict=True) + + # In non-strict mode, should pass but have warnings + result = spec.validate(strict=False) + 
assert result.is_valid() + assert len(result.warnings) > 0 + assert any("multiple primary columns" in w for w in result.warnings) + + def test_column_without_type_or_options_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="empty_col"), + ] + ) + } + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("No type specified" in w for w in result.warnings) + + def test_no_output_destination_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("No output_destination" in w for w in result.warnings) + + def test_unknown_generator_option_warning(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + }, + generator_options={"unknown_option": "value"} + ) + + result = spec.validate(strict=False) + assert result.is_valid() + assert len(result.warnings) > 0 + assert any("Unknown generator option" in w for w in result.warnings) + + def test_multiple_errors_collected(self): + """Test that all errors are collected before raising""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=-10, # Error 1 + partitions=0, # Error 2 + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="id", type="string"), # Error 3: duplicate + ColumnDefinition(name="email", baseColumn="phone"), # Error 4: nonexistent + ] + ) + } + ) + + with pytest.raises(ValueError) as exc_info: + spec.validate() + + error_msg = str(exc_info.value) + # Should contain all errors + assert "invalid number_of_rows" in error_msg + assert "invalid partitions" in error_msg + assert "duplicate column names" in error_msg + assert "does not exist" in error_msg + + def test_strict_mode_raises_on_warnings(self): + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ) + } + # No output_destination - will generate warning + ) + + # Strict mode should raise + with pytest.raises(ValueError): + spec.validate(strict=True) + + # Non-strict mode should pass + result = spec.validate(strict=False) + assert result.is_valid() + + def test_valid_base_column_chain(self): + """Test that valid baseColumn chains work""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ + ColumnDefinition(name="id", type="int", primary=True), + ColumnDefinition(name="code", type="string", baseColumn="id"), + ColumnDefinition(name="hash", type="string", baseColumn="code"), + ] + ) + }, + output_destination=FilePathTarget(base_path="/tmp/data", output_format="parquet") + ) + + result = spec.validate(strict=False) + assert result.is_valid() + + def test_multiple_tables_validation(self): + """Test validation across multiple tables""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=100, + columns=[ColumnDefinition(name="id", type="int", primary=True)] + ), + "orders": TableDefinition( + number_of_rows=-50, # Error in second table + columns=[ColumnDefinition(name="order_id", type="int", primary=True)] + ), + "products": 
TableDefinition( + number_of_rows=200, + columns=[] # Error: no columns + ) + } + ) + + with pytest.raises(ValueError) as exc_info: + spec.validate() + + error_msg = str(exc_info.value) + # Should find errors in both tables + assert "orders" in error_msg + assert "products" in error_msg + + +class TestTargetValidation: + """Tests for output target validation""" + + def test_valid_uc_schema_target(self): + target = UCSchemaTarget(catalog="main", schema_="default") + assert target.catalog == "main" + assert target.schema_ == "default" + + def test_uc_schema_empty_catalog_raises_error(self): + with pytest.raises(ValueError, match="non-empty"): + UCSchemaTarget(catalog="", schema_="default") + + def test_valid_file_path_target(self): + target = FilePathTarget(base_path="/tmp/data", output_format="parquet") + assert target.base_path == "/tmp/data" + assert target.output_format == "parquet" + + def test_file_path_empty_base_path_raises_error(self): + with pytest.raises(ValueError, match="non-empty"): + FilePathTarget(base_path="", output_format="csv") + + def test_file_path_invalid_format_raises_error(self): + with pytest.raises(ValueError): + FilePathTarget(base_path="/tmp/data", output_format="json") + + +class TestValidationIntegration: + """Integration tests for validation""" + + def test_realistic_valid_spec(self): + """Test a realistic, valid specification""" + spec = DatagenSpec( + tables={ + "users": TableDefinition( + number_of_rows=1000, + partitions=4, + columns=[ + ColumnDefinition(name="user_id", type="int", primary=True), + ColumnDefinition(name="username", type="string", options={ + "template": r"\w{8,12}" + }), + ColumnDefinition(name="email", type="string", options={ + "template": r"\w.\w@\w.com" + }), + ColumnDefinition(name="age", type="int", options={ + "min": 18, "max": 99 + }), + ] + ), + "orders": TableDefinition( + number_of_rows=5000, + columns=[ + ColumnDefinition(name="order_id", type="int", primary=True), + ColumnDefinition(name="amount", type="decimal", options={ + "min": 10.0, "max": 1000.0 + }), + ] + ) + }, + output_destination=UCSchemaTarget( + catalog="main", + schema_="synthetic_data" + ), + generator_options={ + "random": True, + "randomSeed": 42 + } + ) + + result = spec.validate(strict=True) + assert result.is_valid() + assert len(result.errors) == 0 + assert len(result.warnings) == 0 \ No newline at end of file From d37de684738cda5e7882a2755e7bbf889911df6f Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Wed, 5 Nov 2025 12:18:14 -0500 Subject: [PATCH 03/20] fixing tests --- dbldatagen/spec/column_spec.py | 55 +++++++++++++ dbldatagen/spec/compat.py | 9 +-- dbldatagen/spec/generator_spec.py | 103 +++++++------------------ dbldatagen/spec/generator_spec_impl.py | 22 +++--- makefile | 2 +- pyproject.toml | 3 +- 6 files changed, 104 insertions(+), 90 deletions(-) create mode 100644 dbldatagen/spec/column_spec.py diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py new file mode 100644 index 00000000..8fd50496 --- /dev/null +++ b/dbldatagen/spec/column_spec.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, Literal + +from .compat import BaseModel, root_validator + + +DbldatagenBasicType = Literal[ + "string", + "int", + "long", + "float", + "double", + "decimal", + "boolean", + "date", + "timestamp", + "short", + "byte", + "binary", + "integer", + "bigint", + "tinyint", +] +class ColumnDefinition(BaseModel): + name: str + type: DbldatagenBasicType | None = None + primary: bool = False + options: 
dict[str, Any] | None = None + nullable: bool | None = False + omit: bool | None = False + baseColumn: str | None = "id" + baseColumnType: str | None = "auto" + + @root_validator() + def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]: + """ + Validates constraints across the entire model after individual fields are processed. + """ + is_primary = values.get("primary") + options = values.get("options") or {} # Handle None case + name = values.get("name") + is_nullable = values.get("nullable") + column_type = values.get("type") + + if is_primary: + if "min" in options or "max" in options: + raise ValueError(f"Primary column '{name}' cannot have min/max options.") + + if is_nullable: + raise ValueError(f"Primary column '{name}' cannot be nullable.") + + if column_type is None: + raise ValueError(f"Primary column '{name}' must have a type defined.") + return values diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py index 7c30d57d..dfafe7b1 100644 --- a/dbldatagen/spec/compat.py +++ b/dbldatagen/spec/compat.py @@ -2,13 +2,12 @@ try: # This will succeed on environments with Pydantic V2.x - # It imports the V1 API that is bundled within V2. - from pydantic.v1 import BaseModel, Field, validator, constr - + from pydantic.v1 import BaseModel, Field, constr, root_validator, validator except ImportError: # This will be executed on environments with only Pydantic V1.x - from pydantic import BaseModel, Field, validator, constr, root_validator, field_validator + from pydantic import BaseModel, Field, constr, root_validator, validator # type: ignore[assignment,no-redef] +__all__ = ["BaseModel", "Field", "constr", "root_validator", "validator"] # In your application code, do this: # from .compat import BaseModel # NOT this: @@ -28,4 +27,4 @@ Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features), you only need to change your application code and your compat.py import statements, making the transition much clearer. -""" \ No newline at end of file +""" diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 23afc4a0..ce3ec9ed 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -1,66 +1,25 @@ -from .compat import BaseModel, validator, root_validator, field_validator -from typing import Dict, Optional, Union, Any, Literal, List +from __future__ import annotations + +import logging +from typing import Any, Literal, Union + import pandas as pd -from IPython.display import display, HTML - -DbldatagenBasicType = Literal[ - "string", - "int", - "long", - "float", - "double", - "decimal", - "boolean", - "date", - "timestamp", - "short", - "byte", - "binary", - "integer", - "bigint", - "tinyint", -] - -class ColumnDefinition(BaseModel): - name: str - type: Optional[DbldatagenBasicType] = None - primary: bool = False - options: Optional[Dict[str, Any]] = {} - nullable: Optional[bool] = False - omit: Optional[bool] = False - baseColumn: Optional[str] = "id" - baseColumnType: Optional[str] = "auto" - - @root_validator(skip_on_failure=True) - def check_model_constraints(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """ - Validates constraints across the entire model after individual fields are processed. 
- """ - is_primary = values.get("primary") - options = values.get("options", {}) - name = values.get("name") - is_nullable = values.get("nullable") - column_type = values.get("type") +from IPython.display import HTML, display - if is_primary: - if "min" in options or "max" in options: - raise ValueError(f"Primary column '{name}' cannot have min/max options.") +from dbldatagen.spec.column_spec import ColumnDefinition - if is_nullable: - raise ValueError(f"Primary column '{name}' cannot be nullable.") +from .compat import BaseModel, validator - if column_type is None: - raise ValueError(f"Primary column '{name}' must have a type defined.") - return values +logger = logging.getLogger(__name__) class UCSchemaTarget(BaseModel): catalog: str schema_: str output_format: str = "delta" # Default to delta for UC Schema - @field_validator("catalog", "schema_", mode="after") - def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argument + @validator("catalog", "schema_") + def validate_identifiers(cls, v: str) -> str: if not v.strip(): raise ValueError("Identifier must be non-empty.") if not v.isidentifier(): @@ -68,7 +27,7 @@ def validate_identifiers(cls, v): # noqa: N805, pylint: disable=no-self-argumen f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.") return v.strip() - def __str__(self): + def __str__(self) -> str: return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" @@ -76,28 +35,28 @@ class FilePathTarget(BaseModel): base_path: str output_format: Literal["csv", "parquet"] # No default, must be specified - @field_validator("base_path", mode="after") - def validate_base_path(cls, v): # noqa: N805, pylint: disable=no-self-argument + @validator("base_path") + def validate_base_path(cls, v: str) -> str: if not v.strip(): raise ValueError("base_path must be non-empty.") return v.strip() - def __str__(self): + def __str__(self) -> str: return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" class TableDefinition(BaseModel): number_of_rows: int - partitions: Optional[int] = None - columns: List[ColumnDefinition] + partitions: int | None = None + columns: list[ColumnDefinition] class ValidationResult: """Container for validation results with errors and warnings.""" def __init__(self) -> None: - self.errors: List[str] = [] - self.warnings: List[str] = [] + self.errors: list[str] = [] + self.warnings: list[str] = [] def add_error(self, message: str) -> None: """Add an error message.""" @@ -132,16 +91,16 @@ def __str__(self) -> str: return "\n".join(lines) class DatagenSpec(BaseModel): - tables: Dict[str, TableDefinition] - output_destination: Optional[Union[UCSchemaTarget, FilePathTarget]] = None # there is a abstraction, may be we can use that? talk to Greg - generator_options: Optional[Dict[str, Any]] = {} - intended_for_databricks: Optional[bool] = None # May be infered. + tables: dict[str, TableDefinition] + output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? talk to Greg + generator_options: dict[str, Any] | None = None + intended_for_databricks: bool | None = None # May be infered. def _check_circular_dependencies( self, table_name: str, - columns: List[ColumnDefinition] - ) -> List[str]: + columns: list[ColumnDefinition] + ) -> list[str]: """ Check for circular dependencies in baseColumn references. Returns a list of error messages if circular dependencies are found. 
@@ -152,13 +111,13 @@ def _check_circular_dependencies( for col in columns: if col.baseColumn and col.baseColumn != "id": # Track the dependency chain - visited = set() + visited: set[str] = set() current = col.name while current: if current in visited: # Found a cycle - cycle_path = " -> ".join(list(visited) + [current]) + cycle_path = " -> ".join([*list(visited), current]) errors.append( f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}" ) @@ -182,7 +141,7 @@ def _check_circular_dependencies( return errors - def validate(self, strict: bool = True) -> ValidationResult: + def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[override] """ Validates the entire DatagenSpec configuration. Always runs all validation checks and collects all errors and warnings. @@ -284,7 +243,7 @@ def validate(self, strict: bool = True) -> ValidationResult: "random", "randomSeed", "randomSeedMethod", "verbose", "debug", "seedColumnName" ] - for key in self.generator_options.keys(): + for key in self.generator_options: if key not in known_options: result.add_warning( f"Unknown generator option: '{key}'. " @@ -292,9 +251,7 @@ def validate(self, strict: bool = True) -> ValidationResult: ) # Now that all validations are complete, decide whether to raise - if strict and (result.errors or result.warnings): - raise ValueError(str(result)) - elif not strict and result.errors: + if (strict and (result.errors or result.warnings)) or (not strict and result.errors): raise ValueError(str(result)) return result diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index a508b1a5..e03e30fb 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -1,11 +1,13 @@ import logging -from typing import Dict, Union import posixpath +from typing import Any, Union -from dbldatagen.spec.generator_spec import TableDefinition from pyspark.sql import SparkSession + import dbldatagen as dg -from .generator_spec import DatagenSpec, UCSchemaTarget, FilePathTarget, ColumnDefinition +from dbldatagen.spec.generator_spec import TableDefinition + +from .generator_spec import ColumnDefinition, DatagenSpec, FilePathTarget, UCSchemaTarget logging.basicConfig( @@ -41,7 +43,7 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> self.app_name = app_name logger.info("Generator initialized with SparkSession") - def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> Dict[str, str]: + def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> dict[str, Any]: """ Convert a ColumnDefinition to dbldatagen column specification. Args: @@ -95,7 +97,7 @@ def _prepare_data_generators( self, config: DatagenSpec, config_source_name: str = "PydanticConfig" - ) -> Dict[str, dg.DataGenerator]: + ) -> dict[str, dg.DataGenerator]: """ Prepare DataGenerator specifications for each table based on the configuration. Args: @@ -117,10 +119,10 @@ def _prepare_data_generators( raise RuntimeError( "SparkSession is not available. 
Cannot prepare data generators") - tables_config: Dict[str, TableDefinition] = config.tables + tables_config: dict[str, TableDefinition] = config.tables global_gen_options = config.generator_options if config.generator_options else {} - prepared_generators: Dict[str, dg.DataGenerator] = {} + prepared_generators: dict[str, dg.DataGenerator] = {} generation_order = list(tables_config.keys()) # This becomes impotant when we get into multitable for table_name in generation_order: @@ -156,7 +158,7 @@ def _prepare_data_generators( def write_prepared_data( self, - prepared_generators: Dict[str, dg.DataGenerator], + prepared_generators: dict[str, dg.DataGenerator], output_destination: Union[UCSchemaTarget, FilePathTarget, None], config_source_name: str = "PydanticConfig", ) -> None: @@ -188,7 +190,7 @@ def write_prepared_data( logger.info( f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})") - if actual_row_count == 0 and requested_rows > 0: + if actual_row_count == 0 and requested_rows is not None and requested_rows > 0: logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0") # Write data based on destination type @@ -251,4 +253,4 @@ def generate_and_write_data( logger.error( f"Error during combined data generation and writing: {e}") raise RuntimeError( - f"Error during combined data generation and writing: {e}") from e \ No newline at end of file + f"Error during combined data generation and writing: {e}") from e diff --git a/makefile b/makefile index a5f4486c..df3e5e6e 100644 --- a/makefile +++ b/makefile @@ -8,7 +8,7 @@ clean: .venv/bin/python: pip install hatch - hatch env create test-pydantic.pydantic==1.10.6-v1 + hatch env create dev: .venv/bin/python @hatch run which python diff --git a/pyproject.toml b/pyproject.toml index 304562e2..99be0820 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ dependencies = [ "jmespath>=0.10.0", "py4j>=0.10.9", "pickleshare>=0.7.5", + "ipython>=7.32.0", ] python="3.10" @@ -431,7 +432,7 @@ check_untyped_defs = true disallow_untyped_decorators = false no_implicit_optional = true warn_redundant_casts = true -warn_unused_ignores = true +warn_unused_ignores = false warn_no_return = true warn_unreachable = true strict_equality = true From f5214caf8b9859aa64dea3228bdcbcfe7306b34f Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Thu, 6 Nov 2025 10:02:23 -0500 Subject: [PATCH 04/20] changes to make file --- makefile | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/makefile b/makefile index df3e5e6e..a551f964 100644 --- a/makefile +++ b/makefile @@ -3,21 +3,18 @@ all: clean dev lint fmt test clean: - rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml + rm -fr clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml rm -fr **/*.pyc -.venv/bin/python: - pip install hatch - hatch env create - -dev: .venv/bin/python +dev: + @which hatch > /dev/null || pip install hatch @hatch run which python lint: - hatch run verify + hatch run test-pydantic.2.8.2:verify fmt: - hatch run fmt + hatch run test-pydantic.2.8.2:fmt test: hatch run test-pydantic:test From 0a8fa2a29977033c0217b85d57631dd2b5ba4aef Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Thu, 6 Nov 2025 13:06:32 -0500 Subject: [PATCH 05/20] updating docs --- dbldatagen/spec/column_spec.py | 56 +++++- dbldatagen/spec/compat.py | 69 +++++--- dbldatagen/spec/generator_spec.py | 226 ++++++++++++++++++++++--- dbldatagen/spec/generator_spec_impl.py | 
187 ++++++++++++++------ 4 files changed, 445 insertions(+), 93 deletions(-) diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py index 8fd50496..74e9e57f 100644 --- a/dbldatagen/spec/column_spec.py +++ b/dbldatagen/spec/column_spec.py @@ -22,7 +22,48 @@ "bigint", "tinyint", ] +"""Type alias representing supported basic Spark SQL data types for column definitions. + +Includes both standard SQL types (e.g. string, int, double) and Spark-specific type names +(e.g. bigint, tinyint). These types are used in the ColumnDefinition to specify the data type +for generated columns. +""" + + class ColumnDefinition(BaseModel): + """Defines the specification for a single column in a synthetic data table. + + This class encapsulates all the information needed to generate data for a single column, + including its name, type, constraints, and generation options. It supports both primary key + columns and derived columns that can reference other columns. + + :param name: Name of the column to be generated + :param type: Spark SQL data type for the column (e.g., "string", "int", "timestamp"). + If None, type may be inferred from options or baseColumn + :param primary: If True, this column will be treated as a primary key column with unique values. + Primary columns cannot have min/max options and cannot be nullable + :param options: Dictionary of additional options controlling column generation behavior. + Common options include: min, max, step, values, template, distribution, etc. + See dbldatagen documentation for full list of available options + :param nullable: If True, the column may contain NULL values. Primary columns cannot be nullable + :param omit: If True, this column will be generated internally but excluded from the final output. + Useful for intermediate columns used in calculations + :param baseColumn: Name of another column to use as the basis for generating this column's values. + Default is "id" which refers to the internal row identifier + :param baseColumnType: Method for deriving values from the baseColumn. Common values: + "auto" (infer behavior), "hash" (hash the base column values), + "values" (use base column values directly) + + .. note:: + Primary columns have special constraints: + - Must have a type defined + - Cannot have min/max options + - Cannot be nullable + + .. note:: + Columns can be chained via baseColumn references, but circular dependencies + will be caught during validation + """ name: str type: DbldatagenBasicType | None = None primary: bool = False @@ -34,8 +75,19 @@ class ColumnDefinition(BaseModel): @root_validator() def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]: - """ - Validates constraints across the entire model after individual fields are processed. + """Validates constraints across the entire ColumnDefinition model. + + This validator runs after all individual field validators and checks for cross-field + constraints that depend on multiple fields being set. It ensures that primary key + columns meet all necessary requirements and that conflicting options are not specified. + + :param values: Dictionary of all field values for this ColumnDefinition instance + :returns: The validated values dictionary, unmodified if all validations pass + :raises ValueError: If primary column has min/max options, or if primary column is nullable, + or if primary column doesn't have a type defined + + .. 
note:: + This is a Pydantic root validator that runs automatically during model instantiation """ is_primary = values.get("primary") options = values.get("options") or {} # Handle None case diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py index dfafe7b1..8fe47508 100644 --- a/dbldatagen/spec/compat.py +++ b/dbldatagen/spec/compat.py @@ -1,30 +1,57 @@ -# This module acts as a compatibility layer for Pydantic V1 and V2. +"""Pydantic compatibility layer for supporting both Pydantic V1 and V2. + +This module provides a unified interface for Pydantic functionality that works across both +Pydantic V1.x and V2.x versions. It ensures that the dbldatagen spec API works in multiple +environments without requiring specific Pydantic version installations. + +The module exports a consistent Pydantic V1-compatible API regardless of which version is installed: + +- **BaseModel**: Base class for all Pydantic models +- **Field**: Field definition with metadata and validation +- **constr**: Constrained string type for validation +- **root_validator**: Decorator for model-level validation +- **validator**: Decorator for field-level validation + +Usage in other modules: + Always import from this compat module, not directly from pydantic:: + + # Correct + from .compat import BaseModel, validator + + # Incorrect - don't do this + from pydantic import BaseModel, validator + +Environment Support: + - **Pydantic V2.x environments**: Imports from pydantic.v1 compatibility layer + - **Pydantic V1.x environments**: Imports directly from pydantic package + - **Databricks runtimes**: Works with pre-installed Pydantic versions without conflicts + +.. note:: + This approach is inspired by FastAPI's compatibility layer: + https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py + +Benefits: + - **No Installation Required**: Works with whatever Pydantic version is available + - **Single Codebase**: One set of code works across both Pydantic versions + - **Environment Agnostic**: Application code doesn't need to know which version is installed + - **Future-Ready**: Easy migration path to Pydantic V2 API when ready + - **Databricks Compatible**: Avoids conflicts with pre-installed libraries + +Future Migration: + When ready to migrate to native Pydantic V2 API: + 1. Update application code to use V2 patterns + 2. Modify this compat.py to import from native V2 locations + 3. Test in both environments + 4. Deploy incrementally +""" try: # This will succeed on environments with Pydantic V2.x + # Pydantic V2 provides a v1 compatibility layer for backwards compatibility from pydantic.v1 import BaseModel, Field, constr, root_validator, validator except ImportError: # This will be executed on environments with only Pydantic V1.x + # Import directly from pydantic since v1 subpackage doesn't exist from pydantic import BaseModel, Field, constr, root_validator, validator # type: ignore[assignment,no-redef] __all__ = ["BaseModel", "Field", "constr", "root_validator", "validator"] -# In your application code, do this: -# from .compat import BaseModel -# NOT this: -# from pydantic import BaseModel - -# FastAPI Notes -# https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py - - -""" -## Why This Approach -No Installation Required: It directly addresses your core requirement. -You don't need to %pip install anything, which avoids conflicts with the pre-installed libraries on Databricks. 
-Single Codebase: You maintain one set of code that is guaranteed to work with the Pydantic V1 API, which is available in both runtimes. - -Environment Agnostic: Your application code in models.py has no idea which version of Pydantic is actually installed. The compat.py module handles that complexity completely. - -Future-Ready: When you eventually decide to migrate fully to the Pydantic V2 API (to take advantage of its speed and features), -you only need to change your application code and your compat.py import statements, making the transition much clearer. -""" diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index ce3ec9ed..d0a750db 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -13,13 +13,44 @@ logger = logging.getLogger(__name__) + class UCSchemaTarget(BaseModel): + """Defines a Unity Catalog schema as the output destination for generated data. + + This class represents a Unity Catalog location (catalog.schema) where generated tables + will be written. Unity Catalog is Databricks' unified governance solution for data and AI. + + :param catalog: Unity Catalog catalog name where tables will be written + :param schema_: Unity Catalog schema (database) name within the catalog + :param output_format: Data format for table storage. Defaults to "delta" which is the + recommended format for Unity Catalog tables + + .. note:: + The schema parameter is named `schema_` (with underscore) to avoid conflict with + Python's built-in schema keyword and Pydantic functionality + + .. note:: + Tables will be written to the location: `{catalog}.{schema_}.{table_name}` + """ catalog: str schema_: str output_format: str = "delta" # Default to delta for UC Schema @validator("catalog", "schema_") def validate_identifiers(cls, v: str) -> str: + """Validates that catalog and schema names are valid identifiers. + + Ensures the identifier is non-empty and follows Python identifier conventions. + Issues a warning if the identifier is not a basic Python identifier, as this may + cause issues with Unity Catalog. + + :param v: The identifier string to validate (catalog or schema name) + :returns: The validated and stripped identifier string + :raises ValueError: If the identifier is empty or contains only whitespace + + .. note:: + This is a Pydantic field validator that runs automatically during model instantiation + """ if not v.strip(): raise ValueError("Identifier must be non-empty.") if not v.isidentifier(): @@ -28,50 +59,130 @@ def validate_identifiers(cls, v: str) -> str: return v.strip() def __str__(self) -> str: + """Returns a human-readable string representation of the Unity Catalog target. + + :returns: Formatted string showing catalog, schema, format and type + """ return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" class FilePathTarget(BaseModel): + """Defines a file system path as the output destination for generated data. + + This class represents a file system location where generated tables will be written + as files. Each table will be written to a subdirectory within the base path. + + :param base_path: Base file system path where table data files will be written. + Each table will be written to {base_path}/{table_name}/ + :param output_format: File format for data storage. Must be either "csv" or "parquet". + No default value - must be explicitly specified + + .. note:: + Unlike UCSchemaTarget, this requires an explicit output_format with no default + + .. 
note:: + The base_path can be a local file system path, DBFS path, or cloud storage path + (e.g., s3://, gs://, abfs://) depending on your environment + """ base_path: str output_format: Literal["csv", "parquet"] # No default, must be specified @validator("base_path") def validate_base_path(cls, v: str) -> str: + """Validates that the base path is non-empty. + + :param v: The base path string to validate + :returns: The validated and stripped base path string + :raises ValueError: If the base path is empty or contains only whitespace + + .. note:: + This is a Pydantic field validator that runs automatically during model instantiation + """ if not v.strip(): raise ValueError("base_path must be non-empty.") return v.strip() def __str__(self) -> str: + """Returns a human-readable string representation of the file path target. + + :returns: Formatted string showing base path, format and type + """ return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" class TableDefinition(BaseModel): + """Defines the complete specification for a single synthetic data table. + + This class encapsulates all the information needed to generate a table of synthetic data, + including the number of rows, partitioning, and column specifications. + + :param number_of_rows: Total number of data rows to generate for this table. + Must be a positive integer + :param partitions: Number of Spark partitions to use when generating data. + If None, defaults to Spark's default parallelism setting. + More partitions can improve generation speed for large datasets + :param columns: List of ColumnDefinition objects specifying the columns to generate + in this table. At least one column must be specified + + .. note:: + Setting an appropriate number of partitions can significantly impact generation performance. + As a rule of thumb, use 2-4 partitions per CPU core available in your Spark cluster + + .. note:: + Column order in the list determines the order of columns in the generated output + """ number_of_rows: int partitions: int | None = None columns: list[ColumnDefinition] class ValidationResult: - """Container for validation results with errors and warnings.""" + """Container for validation results that collects errors and warnings during spec validation. + + This class accumulates validation issues found while checking a DatagenSpec configuration. + It distinguishes between errors (which prevent data generation) and warnings (which + indicate potential issues but don't block generation). + + .. note:: + Validation passes if there are no errors, even if warnings are present + """ def __init__(self) -> None: + """Initialize an empty ValidationResult with no errors or warnings.""" self.errors: list[str] = [] self.warnings: list[str] = [] def add_error(self, message: str) -> None: - """Add an error message.""" + """Add an error message to the validation results. + + Errors indicate critical issues that will prevent successful data generation. + + :param message: Descriptive error message explaining the validation failure + """ self.errors.append(message) def add_warning(self, message: str) -> None: - """Add a warning message.""" + """Add a warning message to the validation results. + + Warnings indicate potential issues or non-optimal configurations that may affect + data generation but won't prevent it from completing. 
+ + :param message: Descriptive warning message explaining the potential issue + """ self.warnings.append(message) def is_valid(self) -> bool: - """Returns True if there are no errors.""" + """Check if validation passed without errors. + + :returns: True if there are no errors (warnings are allowed), False otherwise + """ return len(self.errors) == 0 def __str__(self) -> str: - """String representation of validation results.""" + """Generate a formatted string representation of all validation results. + + :returns: Multi-line string containing formatted errors and warnings with counts + """ lines = [] if self.is_valid(): lines.append("✓ Validation passed successfully") @@ -91,6 +202,34 @@ def __str__(self) -> str: return "\n".join(lines) class DatagenSpec(BaseModel): + """Top-level specification for synthetic data generation across one or more tables. + + This is the main configuration class for the dbldatagen spec-based API. It defines all tables + to be generated, where the output should be written, and global generation options. + + :param tables: Dictionary mapping table names to their TableDefinition specifications. + Keys are the table names that will be used in the output destination + :param output_destination: Target location for generated data. Can be either a + UCSchemaTarget (Unity Catalog) or FilePathTarget (file system). + If None, data will be generated but not persisted + :param generator_options: Dictionary of global options affecting data generation behavior. + Common options include: + - random: Enable random data generation + - randomSeed: Seed for reproducible random generation + - randomSeedMethod: Method for computing random seeds + - verbose: Enable verbose logging + - debug: Enable debug logging + - seedColumnName: Name of internal seed column + :param intended_for_databricks: Flag indicating if this spec is designed for Databricks. + May be automatically inferred based on configuration + + .. note:: + Call the validate() method before using this spec to ensure configuration is correct + + .. note:: + Multiple tables can share the same DatagenSpec and will be generated in the order + they appear in the tables dictionary + """ tables: dict[str, TableDefinition] output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? talk to Greg generator_options: dict[str, Any] | None = None @@ -101,9 +240,19 @@ def _check_circular_dependencies( table_name: str, columns: list[ColumnDefinition] ) -> list[str]: - """ - Check for circular dependencies in baseColumn references. - Returns a list of error messages if circular dependencies are found. + """Check for circular dependencies in baseColumn references within a table. + + Analyzes column dependencies to detect cycles where columns reference each other + in a circular manner (e.g., col A depends on col B, col B depends on col A). + Such circular dependencies would make data generation impossible. + + :param table_name: Name of the table being validated (used in error messages) + :param columns: List of ColumnDefinition objects to check for circular dependencies + :returns: List of error message strings describing any circular dependencies found. + Empty list if no circular dependencies exist + + .. 
note:: + This method performs a graph traversal to detect cycles in the dependency chain """ errors = [] column_map = {col.name: col for col in columns} @@ -142,20 +291,35 @@ def _check_circular_dependencies( return errors def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[override] - """ - Validates the entire DatagenSpec configuration. - Always runs all validation checks and collects all errors and warnings. - - Args: - strict: If True, raises ValueError if any errors or warnings are found. - If False, only raises ValueError if errors (not warnings) are found. - - Returns: - ValidationResult object containing all errors and warnings found. - - Raises: - ValueError: If validation fails based on strict mode setting. - The exception message contains all errors and warnings. + """Validate the entire DatagenSpec configuration comprehensively. + + This method performs extensive validation of the entire spec, including: + - Ensuring at least one table is defined + - Validating each table has columns and positive row counts + - Checking for duplicate column names within tables + - Verifying baseColumn references point to existing columns + - Detecting circular dependencies in baseColumn chains + - Validating primary key constraints + - Checking output destination configuration + - Validating generator options + + All validation checks are performed regardless of whether errors are found, allowing + you to see all issues at once rather than fixing them one at a time. + + :param strict: Controls validation failure behavior: + - If True: Raises ValueError for any errors OR warnings found + - If False: Only raises ValueError for errors (warnings are tolerated) + :returns: ValidationResult object containing all collected errors and warnings, + even if an exception is raised + :raises ValueError: If validation fails based on strict mode setting. + The exception message contains the formatted ValidationResult + + .. note:: + It's recommended to call validate() before attempting to generate data to catch + configuration issues early + + .. note:: + Use strict=False during development to see warnings without blocking generation """ result = ValidationResult() @@ -258,6 +422,24 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove def display_all_tables(self) -> None: + """Display a formatted view of all table definitions in the spec. + + This method provides a user-friendly visualization of the spec configuration, showing + each table's structure and the output destination. It's designed for use in Jupyter + notebooks and will render HTML output when available. + + For each table, displays: + - Table name + - Output destination (or warning if not configured) + - DataFrame showing all columns with their properties + + .. note:: + This method uses IPython.display.HTML when available, falling back to plain text + output in non-notebook environments + + .. note:: + This is intended for interactive exploration and debugging of spec configurations + """ for table_name, table_def in self.tables.items(): print(f"Table: {table_name}") diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index e03e30fb..fc53863e 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -21,18 +21,35 @@ class Generator: - """ - Main data generation orchestrator that handles configuration, preparation, and writing of data. 
+ """Main orchestrator for generating synthetic data from DatagenSpec configurations. + + This class provides the primary interface for the spec-based data generation API. It handles + the complete lifecycle of data generation: + 1. Converting spec configurations into dbldatagen DataGenerator objects + 2. Building the actual data as Spark DataFrames + 3. Writing the data to specified output destinations (Unity Catalog or file system) + + The Generator encapsulates all the complexity of translating declarative specs into + executable data generation plans, allowing users to focus on what data they want rather + than how to generate it. + + :param spark: Active SparkSession to use for data generation + :param app_name: Application name used in logging and tracking. Defaults to "DataGen_ClassBased" + + .. note:: + The Generator requires an active SparkSession. On Databricks, you can use the pre-configured + `spark` variable. For local development, create a SparkSession first + + .. note:: + The same Generator instance can be reused to generate multiple different specs """ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> None: - """ - Initialize the Generator with a SparkSession. - Args: - spark: An existing SparkSession instance - app_name: Application name for logging purposes - Raises: - RuntimeError: If spark is None + """Initialize the Generator with a SparkSession. + + :param spark: An active SparkSession instance to use for data generation operations + :param app_name: Application name for logging and identification purposes + :raises RuntimeError: If spark is None or not properly initialized """ if not spark: logger.error( @@ -44,12 +61,26 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> logger.info("Generator initialized with SparkSession") def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> dict[str, Any]: - """ - Convert a ColumnDefinition to dbldatagen column specification. - Args: - col_def: ColumnDefinition object containing column configuration - Returns: - Dictionary containing dbldatagen column specification + """Convert a ColumnDefinition spec into dbldatagen DataGenerator column arguments. + + This internal method translates the declarative ColumnDefinition format into the + keyword arguments expected by dbldatagen's withColumn() method. It handles special + cases like primary keys, nullable columns, and omitted columns. + + Primary key columns receive special treatment: + - Automatically use the internal ID column as their base + - String primary keys use hash-based generation + - Numeric primary keys maintain sequential values + + :param col_def: ColumnDefinition object from a DatagenSpec + :returns: Dictionary of keyword arguments suitable for DataGenerator.withColumn() + + .. note:: + This is an internal method not intended for direct use by end users + + .. note:: + Conflicting options for primary keys (like min/max, values, expr) will generate + warnings but won't prevent generation - the primary key behavior takes precedence """ col_name = col_def.name col_type = col_def.type @@ -98,17 +129,33 @@ def _prepare_data_generators( config: DatagenSpec, config_source_name: str = "PydanticConfig" ) -> dict[str, dg.DataGenerator]: - """ - Prepare DataGenerator specifications for each table based on the configuration. 
- Args: - config: DatagenSpec Pydantic object containing table configurations - config_source_name: Name for the configuration source (for logging) - Returns: - Dictionary mapping table names to their configured dbldatagen.DataGenerator objects - Raises: - RuntimeError: If SparkSession is not available - ValueError: If any table preparation fails - Exception: If any unexpected error occurs during preparation + """Prepare DataGenerator objects for all tables defined in the spec. + + This internal method is the first phase of data generation. It processes the DatagenSpec + and creates configured dbldatagen.DataGenerator objects for each table, but does not + yet build the actual data. Each table's definition is converted into a DataGenerator + with all its columns configured. + + The method: + 1. Iterates through all tables in the spec + 2. Creates a DataGenerator for each table with appropriate row count and partitioning + 3. Adds all columns to each DataGenerator using withColumn() + 4. Applies global generator options + 5. Returns the prepared generators ready for building + + :param config: DatagenSpec containing table definitions and configuration + :param config_source_name: Descriptive name for the config source, used in logging + and DataGenerator naming + :returns: Dictionary mapping table names to their prepared DataGenerator instances + :raises RuntimeError: If SparkSession is not available or if any table preparation fails + :raises ValueError: If table configuration is invalid (should be caught by validate() first) + + .. note:: + This is an internal method. Use generate_and_write_data() for the complete workflow + + .. note:: + Preparation is separate from building to allow inspection and modification of + DataGenerators before data generation begins """ logger.info( f"Preparing data generators for {len(config.tables)} tables") @@ -162,17 +209,34 @@ def write_prepared_data( output_destination: Union[UCSchemaTarget, FilePathTarget, None], config_source_name: str = "PydanticConfig", ) -> None: - """ - Write data from prepared generators to the specified output destination. - - Args: - prepared_generators: Dictionary of prepared DataGenerator objects - output_destination: Target destination for data output - config_source_name: Name for the configuration source (for logging) - - Raises: - RuntimeError: If any table write fails - ValueError: If output destination is not properly configured + """Build and write data from prepared generators to the specified output destination. + + This method handles the second phase of data generation: taking prepared DataGenerator + objects, building them into actual Spark DataFrames, and writing the results to the + configured output location. + + The method: + 1. Iterates through all prepared generators + 2. Builds each generator into a DataFrame using build() + 3. Writes the DataFrame to the appropriate destination: + - For FilePathTarget: Writes to {base_path}/{table_name}/ in specified format + - For UCSchemaTarget: Writes to {catalog}.{schema}.{table_name} as managed table + 4. Logs row counts and write locations + + :param prepared_generators: Dictionary mapping table names to DataGenerator objects + (typically from _prepare_data_generators()) + :param output_destination: Target location for output. 
Can be UCSchemaTarget, + FilePathTarget, or None (no write, data generated only) + :param config_source_name: Descriptive name for the config source, used in logging + :raises RuntimeError: If DataFrame building or writing fails for any table + :raises ValueError: If output destination type is not recognized + + .. note:: + If output_destination is None, data is generated but not persisted anywhere. + This can be useful for testing or when you want to process the data in-memory + + .. note:: + Writing uses "overwrite" mode, so existing tables/files will be replaced """ logger.info("Starting data writing phase") @@ -216,17 +280,44 @@ def generate_and_write_data( config: DatagenSpec, config_source_name: str = "PydanticConfig" ) -> None: - """ - Combined method to prepare data generators and write data in one operation. - This method orchestrates the complete data generation workflow: - 1. Prepare data generators from configuration - 2. Write data to the specified destination - Args: - config: DatagenSpec Pydantic object containing table configurations - config_source_name: Name for the configuration source (for logging) - Raises: - RuntimeError: If SparkSession is not available or any step fails - ValueError: If critical errors occur during preparation or writing + """Execute the complete data generation workflow from spec to output. + + This is the primary high-level method for generating data from a DatagenSpec. It + orchestrates the entire process in one call, handling both preparation and writing phases. + + The complete workflow: + 1. Validates that the config is properly structured (you should call config.validate() first) + 2. Converts the spec into DataGenerator objects for each table + 3. Builds the DataFrames by executing the generation logic + 4. Writes the results to the configured output destination + 5. Logs progress and completion status + + This method is the recommended entry point for most use cases. For more control over + the generation process, use _prepare_data_generators() and write_prepared_data() separately. + + :param config: DatagenSpec object defining tables, columns, and output destination. + Should be validated with config.validate() before calling this method + :param config_source_name: Descriptive name for the config source, used in logging + and naming DataGenerator instances + :raises RuntimeError: If SparkSession is unavailable, or if preparation or writing fails + :raises ValueError: If the config is invalid (though config.validate() should catch this first) + + .. note:: + It's strongly recommended to call config.validate() before this method to catch + configuration errors early with better error messages + + .. note:: + Generation is performed sequentially: table1 is fully generated and written before + table2 begins. For multi-table generation with dependencies, the order matters + + Example: + >>> spec = DatagenSpec( + ... tables={"users": user_table_def}, + ... output_destination=UCSchemaTarget(catalog="main", schema_="test") + ... 
) + >>> spec.validate() # Check for errors first + >>> generator = Generator(spark) + >>> generator.generate_and_write_data(spec) """ logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables") From 61f676d6b5d4e78d22f85bf5b0527cd9cb14010e Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Mon, 17 Nov 2025 13:26:32 -0500 Subject: [PATCH 06/20] fixing tests, removing the solved todos, targets to a diff module --- dbldatagen/spec/__init__.py | 39 +++ dbldatagen/spec/compat.py | 3 - dbldatagen/spec/generator_spec.py | 100 +----- dbldatagen/spec/output_targets.py | 101 ++++++ examples/datagen_from_specs/README.md | 144 ++++++++ .../basic_stock_ticker_datagen_spec.py | 316 ++++++++++++++++++ .../basic_user_datagen_spec.py | 109 ++++++ tests/test_datagen_specs.py | 280 ++++++++++++++++ tests/test_datasets_with_specs.py | 212 ++++++++++++ 9 files changed, 1202 insertions(+), 102 deletions(-) create mode 100644 dbldatagen/spec/__init__.py create mode 100644 dbldatagen/spec/output_targets.py create mode 100644 examples/datagen_from_specs/README.md create mode 100644 examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py create mode 100644 examples/datagen_from_specs/basic_user_datagen_spec.py create mode 100644 tests/test_datagen_specs.py create mode 100644 tests/test_datasets_with_specs.py diff --git a/dbldatagen/spec/__init__.py b/dbldatagen/spec/__init__.py new file mode 100644 index 00000000..dde2d8a7 --- /dev/null +++ b/dbldatagen/spec/__init__.py @@ -0,0 +1,39 @@ +"""Pydantic-based specification API for dbldatagen. + +This module provides Pydantic models and specifications for defining data generation +in a type-safe, declarative way. +""" + +# Import only the compat layer by default to avoid triggering Spark/heavy dependencies +from .compat import BaseModel, Field, constr, root_validator, validator + +# Lazy imports for heavy modules - import these explicitly when needed +# from .column_spec import ColumnSpec +# from .generator_spec import GeneratorSpec +# from .generator_spec_impl import GeneratorSpecImpl + +__all__ = [ + "BaseModel", + "Field", + "constr", + "root_validator", + "validator", + "ColumnSpec", + "GeneratorSpec", + "GeneratorSpecImpl", +] + + +def __getattr__(name): + """Lazy import heavy modules to avoid triggering Spark initialization.""" + if name == "ColumnSpec": + from .column_spec import ColumnSpec + return ColumnSpec + elif name == "GeneratorSpec": + from .generator_spec import GeneratorSpec + return GeneratorSpec + elif name == "GeneratorSpecImpl": + from .generator_spec_impl import GeneratorSpecImpl + return GeneratorSpecImpl + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py index 8fe47508..72215c0c 100644 --- a/dbldatagen/spec/compat.py +++ b/dbldatagen/spec/compat.py @@ -32,9 +32,6 @@ Benefits: - **No Installation Required**: Works with whatever Pydantic version is available - - **Single Codebase**: One set of code works across both Pydantic versions - - **Environment Agnostic**: Application code doesn't need to know which version is installed - - **Future-Ready**: Easy migration path to Pydantic V2 API when ready - **Databricks Compatible**: Avoids conflicts with pre-installed libraries Future Migration: diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index d0a750db..67f460b7 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -9,105 +9,11 @@ from 
dbldatagen.spec.column_spec import ColumnDefinition from .compat import BaseModel, validator - +from .output_targets import UCSchemaTarget, FilePathTarget logger = logging.getLogger(__name__) -class UCSchemaTarget(BaseModel): - """Defines a Unity Catalog schema as the output destination for generated data. - - This class represents a Unity Catalog location (catalog.schema) where generated tables - will be written. Unity Catalog is Databricks' unified governance solution for data and AI. - - :param catalog: Unity Catalog catalog name where tables will be written - :param schema_: Unity Catalog schema (database) name within the catalog - :param output_format: Data format for table storage. Defaults to "delta" which is the - recommended format for Unity Catalog tables - - .. note:: - The schema parameter is named `schema_` (with underscore) to avoid conflict with - Python's built-in schema keyword and Pydantic functionality - - .. note:: - Tables will be written to the location: `{catalog}.{schema_}.{table_name}` - """ - catalog: str - schema_: str - output_format: str = "delta" # Default to delta for UC Schema - - @validator("catalog", "schema_") - def validate_identifiers(cls, v: str) -> str: - """Validates that catalog and schema names are valid identifiers. - - Ensures the identifier is non-empty and follows Python identifier conventions. - Issues a warning if the identifier is not a basic Python identifier, as this may - cause issues with Unity Catalog. - - :param v: The identifier string to validate (catalog or schema name) - :returns: The validated and stripped identifier string - :raises ValueError: If the identifier is empty or contains only whitespace - - .. note:: - This is a Pydantic field validator that runs automatically during model instantiation - """ - if not v.strip(): - raise ValueError("Identifier must be non-empty.") - if not v.isidentifier(): - logger.warning( - f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.") - return v.strip() - - def __str__(self) -> str: - """Returns a human-readable string representation of the Unity Catalog target. - - :returns: Formatted string showing catalog, schema, format and type - """ - return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" - - -class FilePathTarget(BaseModel): - """Defines a file system path as the output destination for generated data. - - This class represents a file system location where generated tables will be written - as files. Each table will be written to a subdirectory within the base path. - - :param base_path: Base file system path where table data files will be written. - Each table will be written to {base_path}/{table_name}/ - :param output_format: File format for data storage. Must be either "csv" or "parquet". - No default value - must be explicitly specified - - .. note:: - Unlike UCSchemaTarget, this requires an explicit output_format with no default - - .. note:: - The base_path can be a local file system path, DBFS path, or cloud storage path - (e.g., s3://, gs://, abfs://) depending on your environment - """ - base_path: str - output_format: Literal["csv", "parquet"] # No default, must be specified - - @validator("base_path") - def validate_base_path(cls, v: str) -> str: - """Validates that the base path is non-empty. - - :param v: The base path string to validate - :returns: The validated and stripped base path string - :raises ValueError: If the base path is empty or contains only whitespace - - .. 
note:: - This is a Pydantic field validator that runs automatically during model instantiation - """ - if not v.strip(): - raise ValueError("base_path must be non-empty.") - return v.strip() - - def __str__(self) -> str: - """Returns a human-readable string representation of the file path target. - - :returns: Formatted string showing base path, format and type - """ - return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" class TableDefinition(BaseModel): @@ -342,7 +248,6 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove ) # Check partitions if specified - #TODO: though this can be a model field check, we are checking here so that one can correct # Can we find a way to use the default way? if table_def.partitions is not None and table_def.partitions <= 0: result.add_error( @@ -351,7 +256,6 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove ) # Check for duplicate column names - # TODO: Not something possible if we right model, recheck column_names = [col.name for col in table_def.columns] duplicates = [name for name in set(column_names) if column_names.count(name) > 1] if duplicates: @@ -361,8 +265,6 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove # Build column map for reference checking column_map = {col.name: col for col in table_def.columns} - - # TODO: Check baseColumn references, this is tricky? check the dbldefaults for col in table_def.columns: if col.baseColumn and col.baseColumn != "id": if col.baseColumn not in column_map: diff --git a/dbldatagen/spec/output_targets.py b/dbldatagen/spec/output_targets.py new file mode 100644 index 00000000..f9f51194 --- /dev/null +++ b/dbldatagen/spec/output_targets.py @@ -0,0 +1,101 @@ +from .compat import BaseModel, validator +from typing import Literal +import logging + +logger = logging.getLogger(__name__) + + +class UCSchemaTarget(BaseModel): + """Defines a Unity Catalog schema as the output destination for generated data. + + This class represents a Unity Catalog location (catalog.schema) where generated tables + will be written. Unity Catalog is Databricks' unified governance solution for data and AI. + + :param catalog: Unity Catalog catalog name where tables will be written + :param schema_: Unity Catalog schema (database) name within the catalog + :param output_format: Data format for table storage. Defaults to "delta" which is the + recommended format for Unity Catalog tables + + .. note:: + The schema parameter is named `schema_` (with underscore) to avoid conflict with + Python's built-in schema keyword and Pydantic functionality + + .. note:: + Tables will be written to the location: `{catalog}.{schema_}.{table_name}` + """ + catalog: str + schema_: str + output_format: str = "delta" # Default to delta for UC Schema + + @validator("catalog", "schema_") + def validate_identifiers(cls, v: str) -> str: + """Validates that catalog and schema names are valid identifiers. + + Ensures the identifier is non-empty and follows Python identifier conventions. + Issues a warning if the identifier is not a basic Python identifier, as this may + cause issues with Unity Catalog. + + :param v: The identifier string to validate (catalog or schema name) + :returns: The validated and stripped identifier string + :raises ValueError: If the identifier is empty or contains only whitespace + + .. 
note:: + This is a Pydantic field validator that runs automatically during model instantiation + """ + if not v.strip(): + raise ValueError("Identifier must be non-empty.") + if not v.isidentifier(): + logger.warning( + f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.") + return v.strip() + + def __str__(self) -> str: + """Returns a human-readable string representation of the Unity Catalog target. + + :returns: Formatted string showing catalog, schema, format and type + """ + return f"{self.catalog}.{self.schema_} (Format: {self.output_format}, Type: UC Table)" + + +class FilePathTarget(BaseModel): + """Defines a file system path as the output destination for generated data. + + This class represents a file system location where generated tables will be written + as files. Each table will be written to a subdirectory within the base path. + + :param base_path: Base file system path where table data files will be written. + Each table will be written to {base_path}/{table_name}/ + :param output_format: File format for data storage. Must be either "csv" or "parquet". + No default value - must be explicitly specified + + .. note:: + Unlike UCSchemaTarget, this requires an explicit output_format with no default + + .. note:: + The base_path can be a local file system path, DBFS path, or cloud storage path + (e.g., s3://, gs://, abfs://) depending on your environment + """ + base_path: str + output_format: Literal["csv", "parquet"] # No default, must be specified + + @validator("base_path") + def validate_base_path(cls, v: str) -> str: + """Validates that the base path is non-empty. + + :param v: The base path string to validate + :returns: The validated and stripped base path string + :raises ValueError: If the base path is empty or contains only whitespace + + .. note:: + This is a Pydantic field validator that runs automatically during model instantiation + """ + if not v.strip(): + raise ValueError("base_path must be non-empty.") + return v.strip() + + def __str__(self) -> str: + """Returns a human-readable string representation of the file path target. + + :returns: Formatted string showing base path, format and type + """ + return f"{self.base_path} (Format: {self.output_format}, Type: File Path)" diff --git a/examples/datagen_from_specs/README.md b/examples/datagen_from_specs/README.md new file mode 100644 index 00000000..31e6bc9e --- /dev/null +++ b/examples/datagen_from_specs/README.md @@ -0,0 +1,144 @@ +# Dataset Specifications with Pydantic + +This module provides Pydantic model specifications for common datasets available in `dbldatagen.datasets`. These models can be used for type validation, API schemas, and documentation. 
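+
+For example, rows produced by the corresponding dbldatagen dataset providers (see
+"Correspondence with Data Generators" below) can be validated directly against these
+models. This is an illustrative sketch, assuming an active SparkSession `spark` and
+that the provider's column names match the model fields:
+
+```python
+import dbldatagen as dg
+from dbldatagen.datasets_with_specs import BasicUser
+
+# Generate a small sample with the "basic/user" provider and validate each row.
+df = dg.Datasets(spark, "basic/user").get(rows=100).build()
+users = [BasicUser(**row.asDict()) for row in df.collect()]
+```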
+ +## Available Models + +### BasicUser +Represents user data with customer information: +- `customer_id`: Unique customer identifier (integer >= 1000000) +- `name`: Customer name (string) +- `email`: Email address (string) +- `ip_addr`: IP address (string) +- `phone`: Phone number (string) + +### BasicStockTicker +Represents stock ticker time-series data: +- `symbol`: Stock ticker symbol (string, 1-10 characters) +- `post_date`: Trading date (date) +- `open`: Opening price (Decimal >= 0, 2 decimal places) +- `close`: Closing price (Decimal >= 0, 2 decimal places) +- `high`: Highest price (Decimal >= 0, 2 decimal places) +- `low`: Lowest price (Decimal >= 0, 2 decimal places) +- `adj_close`: Adjusted closing price (Decimal >= 0, 2 decimal places) +- `volume`: Trading volume (integer >= 0) + +## Usage + +### Basic Usage + +```python +from dbldatagen.datasets_with_specs import BasicUser, BasicStockTicker +from datetime import date +from decimal import Decimal + +# Create a user instance +user = BasicUser( + customer_id=1234567890, + name="John Doe", + email="john.doe@example.com", + ip_addr="192.168.1.100", + phone="(555)-123-4567" +) + +# Create a stock ticker instance +ticker = BasicStockTicker( + symbol="AAPL", + post_date=date(2024, 10, 15), + open=Decimal("150.25"), + close=Decimal("152.50"), + high=Decimal("153.75"), + low=Decimal("149.80"), + adj_close=Decimal("152.35"), + volume=2500000 +) +``` + +### Validation + +The models automatically validate data: + +```python +# This will raise a validation error (customer_id too small) +try: + user = BasicUser( + customer_id=100, # Must be >= 1000000 + name="Jane Doe", + email="jane@example.com", + ip_addr="10.0.0.1", + phone="555-1234" + ) +except ValidationError as e: + print(f"Validation failed: {e}") +``` + +### Serialization + +```python +# Convert to dictionary +user_dict = user.dict() + +# Convert to JSON +user_json = user.json() + +# Parse from JSON +user_from_json = BasicUser.parse_raw(user_json) + +# Get JSON schema +schema = BasicUser.schema_json(indent=2) +``` + +### Integration with FastAPI + +```python +from fastapi import FastAPI +from dbldatagen.datasets_with_specs import BasicUser + +app = FastAPI() + +@app.post("/users/") +async def create_user(user: BasicUser): + # FastAPI will automatically validate the request body + return {"user_id": user.customer_id, "name": user.name} +``` + +### Integration with Pandas + +```python +import pandas as pd +from dbldatagen.datasets_with_specs import BasicStockTicker + +# Create DataFrame from Pydantic models +tickers = [ + BasicStockTicker(...), + BasicStockTicker(...), +] + +df = pd.DataFrame([ticker.dict() for ticker in tickers]) +``` + +## Correspondence with Data Generators + +These models correspond to the data generated by the providers in `dbldatagen.datasets`: + +- `BasicUser` ↔ `BasicUserProvider` (dataset name: "basic/user") +- `BasicStockTicker` ↔ `BasicStockTickerProvider` (dataset name: "basic/stock_ticker") + +The Pydantic models define the schema and validation rules, while the providers generate the actual data using the dbldatagen framework. + +## Benefits + +1. **Type Safety**: Catch type errors at development time +2. **Validation**: Automatic data validation with detailed error messages +3. **Documentation**: Self-documenting code with field descriptions +4. **API Integration**: Direct integration with FastAPI and other Pydantic-based frameworks +5. **Schema Generation**: Generate JSON schemas for documentation and code generation +6. 
**IDE Support**: Better autocomplete and type hints in IDEs + +## Requirements + +- Python 3.8+ +- Pydantic 1.x or 2.x (with v1 compatibility layer) + +The module uses the compatibility layer in `dbldatagen.spec.compat` to work with both Pydantic v1 and v2. + diff --git a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py new file mode 100644 index 00000000..5cc18edd --- /dev/null +++ b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py @@ -0,0 +1,316 @@ +"""DatagenSpec for Basic Stock Ticker Dataset. + +This module defines a declarative Pydantic-based specification for generating +the basic stock ticker dataset, corresponding to the BasicStockTickerProvider. +""" + +from random import random + +from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.column_spec import ColumnDefinition + + +def create_basic_stock_ticker_spec( + number_of_rows: int = 100000, + partitions: int | None = None, + num_symbols: int = 100, + start_date: str = "2024-10-01" +) -> DatagenSpec: + """Create a DatagenSpec for basic stock ticker data generation. + + This function creates a declarative specification matching the data generated + by BasicStockTickerProvider in the datasets module. It generates time-series + stock data with OHLC (Open, High, Low, Close) values, adjusted close, and volume. + + Args: + number_of_rows: Total number of rows to generate (default: 100,000) + partitions: Number of Spark partitions to use (default: auto-computed) + num_symbols: Number of unique stock ticker symbols to generate (default: 100) + start_date: Starting date for stock data in 'YYYY-MM-DD' format (default: "2024-10-01") + + Returns: + DatagenSpec configured for basic stock ticker data generation + + Example: + >>> spec = create_basic_stock_ticker_spec( + ... number_of_rows=10000, + ... num_symbols=50, + ... start_date="2024-01-01" + ... ) + >>> spec.validate() + >>> # Use with GeneratorSpecImpl to generate data + + Note: + The stock prices use a growth model with volatility to simulate realistic + price movements over time. Each symbol gets its own growth rate and volatility. 
+ """ + # Generate random values for start_value, growth_rate, and volatility + # These need to be pre-computed for the values option + num_value_sets = max(1, int(num_symbols / 10)) + start_values = [1.0 + 199.0 * random() for _ in range(num_value_sets)] + growth_rates = [-0.1 + 0.35 * random() for _ in range(num_value_sets)] + volatility_values = [0.0075 * random() for _ in range(num_value_sets)] + + columns = [ + # Symbol ID (numeric identifier for symbol) + ColumnDefinition( + name="symbol_id", + type="long", + options={ + "minValue": 676, + "maxValue": 676 + num_symbols - 1 + } + ), + + # Random value helper (omitted from output) + ColumnDefinition( + name="rand_value", + type="float", + options={ + "minValue": 0.0, + "maxValue": 1.0, + "step": 0.1 + }, + baseColumn="symbol_id", + omit=True + ), + + # Stock symbol (derived from symbol_id using base-26 conversion) + ColumnDefinition( + name="symbol", + type="string", + options={ + "expr": """concat_ws('', transform(split(conv(symbol_id, 10, 26), ''), + x -> case when ascii(x) < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""" + } + ), + + # Days offset from start date (omitted from output) + ColumnDefinition( + name="days_from_start_date", + type="int", + options={ + "expr": f"floor(try_divide(id, {num_symbols}))" + }, + omit=True + ), + + # Post date (trading date) + ColumnDefinition( + name="post_date", + type="date", + options={ + "expr": f"date_add(cast('{start_date}' as date), days_from_start_date)" + } + ), + + # Starting price for each symbol (omitted from output) + ColumnDefinition( + name="start_value", + type="decimal", + options={ + "values": start_values + }, + omit=True + ), + + # Growth rate for each symbol + ColumnDefinition( + name="growth_rate", + type="float", + options={ + "values": growth_rates + }, + baseColumn="symbol_id" + ), + + # Volatility for each symbol (omitted from output) + ColumnDefinition( + name="volatility", + type="float", + options={ + "values": volatility_values + }, + baseColumn="symbol_id", + omit=True + ), + + # Previous day's modifier sign (omitted from output) + ColumnDefinition( + name="prev_modifier_sign", + type="float", + options={ + "expr": f"case when sin((id - {num_symbols}) % 17) > 0 then -1.0 else 1.0 end" + }, + omit=True + ), + + # Current day's modifier sign (omitted from output) + ColumnDefinition( + name="modifier_sign", + type="float", + options={ + "expr": "case when sin(id % 17) > 0 then -1.0 else 1.0 end" + }, + omit=True + ), + + # Base opening price (omitted from output) + ColumnDefinition( + name="open_base", + type="decimal", + options={ + "expr": f"""start_value + + (volatility * prev_modifier_sign * start_value * sin((id - {num_symbols}) % 17)) + + (growth_rate * start_value * try_divide(days_from_start_date - 1, 365))""" + }, + omit=True + ), + + # Base closing price (omitted from output) + ColumnDefinition( + name="close_base", + type="decimal", + options={ + "expr": """start_value + + (volatility * start_value * sin(id % 17)) + + (growth_rate * start_value * try_divide(days_from_start_date, 365))""" + }, + omit=True + ), + + # Base high price (omitted from output) + ColumnDefinition( + name="high_base", + type="decimal", + options={ + "expr": "greatest(open_base, close_base) + rand() * volatility * open_base" + }, + omit=True + ), + + # Base low price (omitted from output) + ColumnDefinition( + name="low_base", + type="decimal", + options={ + "expr": "least(open_base, close_base) - rand() * volatility * open_base" + }, + omit=True + ), + + # Final 
opening price (output column) + ColumnDefinition( + name="open", + type="decimal", + options={ + "expr": "greatest(open_base, 0.0)" + } + ), + + # Final closing price (output column) + ColumnDefinition( + name="close", + type="decimal", + options={ + "expr": "greatest(close_base, 0.0)" + } + ), + + # Final high price (output column) + ColumnDefinition( + name="high", + type="decimal", + options={ + "expr": "greatest(high_base, 0.0)" + } + ), + + # Final low price (output column) + ColumnDefinition( + name="low", + type="decimal", + options={ + "expr": "greatest(low_base, 0.0)" + } + ), + + # Dividend (omitted from output) + ColumnDefinition( + name="dividend", + type="decimal", + options={ + "expr": "0.05 * rand_value * close" + }, + omit=True + ), + + # Adjusted closing price (output column) + ColumnDefinition( + name="adj_close", + type="decimal", + options={ + "expr": "greatest(close - dividend, 0.0)" + } + ), + + # Trading volume (output column) + ColumnDefinition( + name="volume", + type="long", + options={ + "minValue": 100000, + "maxValue": 5000000, + "random": True + } + ), + ] + + table_def = TableDefinition( + number_of_rows=number_of_rows, + partitions=partitions, + columns=columns + ) + + spec = DatagenSpec( + tables={"stock_tickers": table_def}, + output_destination=None, # No automatic persistence + generator_options={ + "randomSeedMethod": "hash_fieldname" + } + ) + + return spec + + +# Pre-configured specs for common use cases +BASIC_STOCK_TICKER_SPEC_SMALL = create_basic_stock_ticker_spec( + number_of_rows=1000, + num_symbols=10, + start_date="2024-10-01" +) +"""Pre-configured spec for small dataset (1,000 rows, 10 symbols)""" + +BASIC_STOCK_TICKER_SPEC_MEDIUM = create_basic_stock_ticker_spec( + number_of_rows=100000, + num_symbols=100, + start_date="2024-10-01" +) +"""Pre-configured spec for medium dataset (100,000 rows, 100 symbols)""" + +BASIC_STOCK_TICKER_SPEC_LARGE = create_basic_stock_ticker_spec( + number_of_rows=1000000, + num_symbols=500, + start_date="2024-01-01" +) +"""Pre-configured spec for large dataset (1,000,000 rows, 500 symbols, full year)""" + +BASIC_STOCK_TICKER_SPEC_ONE_YEAR = create_basic_stock_ticker_spec( + number_of_rows=36500, # 100 symbols * 365 days + num_symbols=100, + start_date="2024-01-01" +) +"""Pre-configured spec for one year of daily data (100 symbols, 365 days)""" + + + diff --git a/examples/datagen_from_specs/basic_user_datagen_spec.py b/examples/datagen_from_specs/basic_user_datagen_spec.py new file mode 100644 index 00000000..ef0077e2 --- /dev/null +++ b/examples/datagen_from_specs/basic_user_datagen_spec.py @@ -0,0 +1,109 @@ +"""DatagenSpec for Basic User Dataset. + +This module defines a declarative Pydantic-based specification for generating +the basic user dataset, corresponding to the BasicUserProvider. +""" + +from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.column_spec import ColumnDefinition + + +def create_basic_user_spec( + number_of_rows: int = 100000, + partitions: int | None = None, + random: bool = False +) -> DatagenSpec: + """Create a DatagenSpec for basic user data generation. + + This function creates a declarative specification matching the data generated + by BasicUserProvider in the datasets module. 
+ + Args: + number_of_rows: Total number of rows to generate (default: 100,000) + partitions: Number of Spark partitions to use (default: auto-computed) + random: If True, generates random data; if False, uses deterministic patterns + + Returns: + DatagenSpec configured for basic user data generation + + Example: + >>> spec = create_basic_user_spec(number_of_rows=1000, random=True) + >>> spec.validate() + >>> # Use with GeneratorSpecImpl to generate data + """ + MAX_LONG = 9223372036854775807 + + columns = [ + ColumnDefinition( + name="customer_id", + type="long", + options={ + "minValue": 1000000, + "maxValue": MAX_LONG, + "random": random + } + ), + ColumnDefinition( + name="name", + type="string", + options={ + "template": r"\w \w|\w \w \w", + "random": random + } + ), + ColumnDefinition( + name="email", + type="string", + options={ + "template": r"\w.\w@\w.com|\w@\w.co.u\k", + "random": random + } + ), + ColumnDefinition( + name="ip_addr", + type="string", + options={ + "template": r"\n.\n.\n.\n", + "random": random + } + ), + ColumnDefinition( + name="phone", + type="string", + options={ + "template": r"(ddd)-ddd-dddd|1(ddd) ddd-dddd|ddd ddddddd", + "random": random + } + ), + ] + + table_def = TableDefinition( + number_of_rows=number_of_rows, + partitions=partitions, + columns=columns + ) + + spec = DatagenSpec( + tables={"users": table_def}, + output_destination=None, # No automatic persistence + generator_options={ + "randomSeedMethod": "hash_fieldname" + } + ) + + return spec + + +# Pre-configured specs for common use cases +BASIC_USER_SPEC_SMALL = create_basic_user_spec(number_of_rows=1000, random=False) +"""Pre-configured spec for small dataset (1,000 rows, deterministic)""" + +BASIC_USER_SPEC_MEDIUM = create_basic_user_spec(number_of_rows=100000, random=False) +"""Pre-configured spec for medium dataset (100,000 rows, deterministic)""" + +BASIC_USER_SPEC_LARGE = create_basic_user_spec(number_of_rows=1000000, random=False) +"""Pre-configured spec for large dataset (1,000,000 rows, deterministic)""" + +BASIC_USER_SPEC_RANDOM = create_basic_user_spec(number_of_rows=100000, random=True) +"""Pre-configured spec for random data (100,000 rows, random)""" + diff --git a/tests/test_datagen_specs.py b/tests/test_datagen_specs.py new file mode 100644 index 00000000..105c3eeb --- /dev/null +++ b/tests/test_datagen_specs.py @@ -0,0 +1,280 @@ +"""Tests for DatagenSpec specifications for datasets.""" + +import unittest + +# Import DatagenSpec classes directly to avoid Spark initialization +from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.column_spec import ColumnDefinition + + +class TestBasicUserDatagenSpec(unittest.TestCase): + """Tests for BasicUser DatagenSpec.""" + + def test_basic_user_spec_creation(self): + """Test creating a basic user DatagenSpec.""" + columns = [ + ColumnDefinition( + name="customer_id", + type="long", + options={"minValue": 1000000, "maxValue": 9999999999} + ), + ColumnDefinition( + name="name", + type="string", + options={"template": r"\w \w"} + ), + ColumnDefinition( + name="email", + type="string", + options={"template": r"\w@\w.com"} + ), + ] + + table_def = TableDefinition( + number_of_rows=1000, + partitions=2, + columns=columns + ) + + spec = DatagenSpec( + tables={"users": table_def}, + output_destination=None + ) + + self.assertIsNotNone(spec) + self.assertIn("users", spec.tables) + self.assertEqual(spec.tables["users"].number_of_rows, 1000) + self.assertEqual(spec.tables["users"].partitions, 2) + 
self.assertEqual(len(spec.tables["users"].columns), 3) + + def test_basic_user_spec_validation(self): + """Test validating a basic user DatagenSpec.""" + columns = [ + ColumnDefinition( + name="customer_id", + type="long", + options={"minValue": 1000000} + ), + ColumnDefinition( + name="name", + type="string", + options={"template": r"\w \w"} + ), + ] + + table_def = TableDefinition( + number_of_rows=100, + columns=columns + ) + + spec = DatagenSpec( + tables={"users": table_def} + ) + + validation_result = spec.validate(strict=False) + self.assertTrue(validation_result.is_valid()) + self.assertEqual(len(validation_result.errors), 0) + + def test_column_with_base_column(self): + """Test creating columns that depend on other columns.""" + columns = [ + ColumnDefinition( + name="symbol_id", + type="long", + options={"minValue": 1, "maxValue": 100} + ), + ColumnDefinition( + name="symbol", + type="string", + options={ + "expr": "concat('SYM', symbol_id)" + } + ), + ] + + table_def = TableDefinition( + number_of_rows=50, + columns=columns + ) + + spec = DatagenSpec( + tables={"symbols": table_def} + ) + + validation_result = spec.validate(strict=False) + self.assertTrue(validation_result.is_valid()) + + +class TestBasicStockTickerDatagenSpec(unittest.TestCase): + """Tests for BasicStockTicker DatagenSpec.""" + + def test_basic_stock_ticker_spec_creation(self): + """Test creating a basic stock ticker DatagenSpec.""" + columns = [ + ColumnDefinition( + name="symbol", + type="string", + options={"template": r"\u\u\u"} + ), + ColumnDefinition( + name="post_date", + type="date", + options={"expr": "date_add(cast('2024-10-01' as date), floor(id / 100))"} + ), + ColumnDefinition( + name="open", + type="decimal", + options={"minValue": 100.0, "maxValue": 500.0} + ), + ColumnDefinition( + name="close", + type="decimal", + options={"minValue": 100.0, "maxValue": 500.0} + ), + ColumnDefinition( + name="volume", + type="long", + options={"minValue": 100000, "maxValue": 5000000} + ), + ] + + table_def = TableDefinition( + number_of_rows=1000, + partitions=2, + columns=columns + ) + + spec = DatagenSpec( + tables={"stock_tickers": table_def}, + output_destination=None + ) + + self.assertIsNotNone(spec) + self.assertIn("stock_tickers", spec.tables) + self.assertEqual(spec.tables["stock_tickers"].number_of_rows, 1000) + self.assertEqual(len(spec.tables["stock_tickers"].columns), 5) + + def test_stock_ticker_with_omitted_columns(self): + """Test creating spec with omitted intermediate columns.""" + columns = [ + ColumnDefinition( + name="base_price", + type="decimal", + options={"minValue": 100.0, "maxValue": 500.0}, + omit=True # Intermediate column + ), + ColumnDefinition( + name="open", + type="decimal", + options={"expr": "base_price * 0.99"} + ), + ColumnDefinition( + name="close", + type="decimal", + options={"expr": "base_price * 1.01"} + ), + ] + + table_def = TableDefinition( + number_of_rows=100, + columns=columns + ) + + spec = DatagenSpec( + tables={"prices": table_def} + ) + + validation_result = spec.validate(strict=False) + self.assertTrue(validation_result.is_valid()) + + # Check that omitted column is present + omitted_cols = [col for col in columns if col.omit] + self.assertEqual(len(omitted_cols), 1) + self.assertEqual(omitted_cols[0].name, "base_price") + + +class TestDatagenSpecValidation(unittest.TestCase): + """Tests for DatagenSpec validation.""" + + def test_empty_tables_validation(self): + """Test that spec with no tables fails validation.""" + spec = DatagenSpec(tables={}) + + with 
self.assertRaises(ValueError) as context: + spec.validate(strict=False) + + # Verify error message mentions missing tables + self.assertIn("at least one table", str(context.exception)) + + def test_duplicate_column_names(self): + """Test that duplicate column names are caught.""" + columns = [ + ColumnDefinition(name="id", type="long"), + ColumnDefinition(name="id", type="string"), # Duplicate! + ] + + table_def = TableDefinition( + number_of_rows=100, + columns=columns + ) + + spec = DatagenSpec(tables={"test": table_def}) + + with self.assertRaises(ValueError) as context: + spec.validate(strict=False) + + # Verify the error message mentions duplicates + self.assertIn("duplicate column names", str(context.exception)) + self.assertIn("id", str(context.exception)) + + + def test_negative_rows_validation(self): + """Test that negative row counts fail validation.""" + columns = [ + ColumnDefinition(name="col1", type="long") + ] + + # Create with negative rows using dict to bypass Pydantic validation + table_def = TableDefinition( + number_of_rows=-100, # Invalid + columns=columns + ) + + spec = DatagenSpec(tables={"test": table_def}) + + with self.assertRaises(ValueError) as context: + spec.validate(strict=False) + + # Verify error message mentions invalid number_of_rows + self.assertIn("invalid number_of_rows", str(context.exception)) + self.assertIn("-100", str(context.exception)) + + def test_spec_with_generator_options(self): + """Test creating spec with generator options.""" + columns = [ + ColumnDefinition(name="value", type="long") + ] + + table_def = TableDefinition( + number_of_rows=100, + columns=columns + ) + + spec = DatagenSpec( + tables={"test": table_def}, + generator_options={ + "randomSeedMethod": "hash_fieldname", + "verbose": True + } + ) + + self.assertIsNotNone(spec.generator_options) + self.assertEqual(spec.generator_options["randomSeedMethod"], "hash_fieldname") + self.assertTrue(spec.generator_options["verbose"]) + + +if __name__ == "__main__": + unittest.main() + + + diff --git a/tests/test_datasets_with_specs.py b/tests/test_datasets_with_specs.py new file mode 100644 index 00000000..549260f5 --- /dev/null +++ b/tests/test_datasets_with_specs.py @@ -0,0 +1,212 @@ +"""Tests for Pydantic dataset specification models.""" + +import unittest +from datetime import date +from decimal import Decimal + +# Import Pydantic directly to avoid Spark initialization issues in test environment +try: + from pydantic.v1 import BaseModel, Field, ValidationError +except ImportError: + from pydantic import BaseModel, Field, ValidationError # type: ignore + + +class TestBasicUserSpec(unittest.TestCase): + """Tests for BasicUser Pydantic model.""" + + def setUp(self): + """Set up test fixtures - define model inline to avoid import issues.""" + # Define the model inline to avoid triggering Spark imports + class BasicUser(BaseModel): + customer_id: int = Field(..., ge=1000000) + name: str = Field(..., min_length=1) + email: str + ip_addr: str + phone: str + + self.BasicUser = BasicUser + + def test_valid_user_creation(self): + """Test creating a valid user instance.""" + user = self.BasicUser( + customer_id=1234567890, + name="John Doe", + email="john.doe@example.com", + ip_addr="192.168.1.100", + phone="(555)-123-4567" + ) + + self.assertEqual(user.customer_id, 1234567890) + self.assertEqual(user.name, "John Doe") + self.assertEqual(user.email, "john.doe@example.com") + self.assertEqual(user.ip_addr, "192.168.1.100") + self.assertEqual(user.phone, "(555)-123-4567") + + def 
test_invalid_customer_id(self): + """Test that small customer_id is rejected.""" + with self.assertRaises(ValidationError) as context: + self.BasicUser( + customer_id=100, # Too small + name="Jane Smith", + email="jane@example.com", + ip_addr="10.0.0.1", + phone="555-1234" + ) + + error = context.exception + self.assertIn("customer_id", str(error)) + + def test_user_dict_conversion(self): + """Test converting user to dictionary.""" + user = self.BasicUser( + customer_id=1234567890, + name="John Doe", + email="john.doe@example.com", + ip_addr="192.168.1.100", + phone="(555)-123-4567" + ) + + user_dict = user.dict() + self.assertIsInstance(user_dict, dict) + self.assertEqual(user_dict["customer_id"], 1234567890) + self.assertEqual(user_dict["name"], "John Doe") + + def test_user_json_serialization(self): + """Test JSON serialization.""" + user = self.BasicUser( + customer_id=1234567890, + name="John Doe", + email="john.doe@example.com", + ip_addr="192.168.1.100", + phone="(555)-123-4567" + ) + + json_str = user.json() + self.assertIsInstance(json_str, str) + self.assertIn("1234567890", json_str) + self.assertIn("John Doe", json_str) + + # Test parsing back + user_from_json = self.BasicUser.parse_raw(json_str) + self.assertEqual(user_from_json.customer_id, user.customer_id) + self.assertEqual(user_from_json.name, user.name) + + +class TestBasicStockTickerSpec(unittest.TestCase): + """Tests for BasicStockTicker Pydantic model.""" + + def setUp(self): + """Set up test fixtures - define model inline to avoid import issues.""" + class BasicStockTicker(BaseModel): + symbol: str = Field(..., min_length=1, max_length=10) + post_date: date + open: Decimal = Field(..., ge=0) + close: Decimal = Field(..., ge=0) + high: Decimal = Field(..., ge=0) + low: Decimal = Field(..., ge=0) + adj_close: Decimal = Field(..., ge=0) + volume: int = Field(..., ge=0) + + self.BasicStockTicker = BasicStockTicker + + def test_valid_ticker_creation(self): + """Test creating a valid stock ticker instance.""" + ticker = self.BasicStockTicker( + symbol="AAPL", + post_date=date(2024, 10, 15), + open=Decimal("150.25"), + close=Decimal("152.50"), + high=Decimal("153.75"), + low=Decimal("149.80"), + adj_close=Decimal("152.35"), + volume=2500000 + ) + + self.assertEqual(ticker.symbol, "AAPL") + self.assertEqual(ticker.post_date, date(2024, 10, 15)) + self.assertEqual(ticker.open, Decimal("150.25")) + self.assertEqual(ticker.close, Decimal("152.50")) + self.assertEqual(ticker.high, Decimal("153.75")) + self.assertEqual(ticker.low, Decimal("149.80")) + self.assertEqual(ticker.adj_close, Decimal("152.35")) + self.assertEqual(ticker.volume, 2500000) + + def test_invalid_volume(self): + """Test that negative volume is rejected.""" + with self.assertRaises(ValidationError) as context: + self.BasicStockTicker( + symbol="MSFT", + post_date=date(2024, 10, 16), + open=Decimal("300.00"), + close=Decimal("305.00"), + high=Decimal("310.00"), + low=Decimal("295.00"), + adj_close=Decimal("304.50"), + volume=-1000 # Negative + ) + + error = context.exception + self.assertIn("volume", str(error)) + + def test_invalid_negative_price(self): + """Test that negative prices are rejected.""" + with self.assertRaises(ValidationError) as context: + self.BasicStockTicker( + symbol="GOOGL", + post_date=date(2024, 10, 17), + open=Decimal("-100.00"), # Negative + close=Decimal("305.00"), + high=Decimal("310.00"), + low=Decimal("295.00"), + adj_close=Decimal("304.50"), + volume=1000000 + ) + + error = context.exception + self.assertIn("open", str(error)) + + 
def test_ticker_dict_conversion(self): + """Test converting ticker to dictionary.""" + ticker = self.BasicStockTicker( + symbol="AAPL", + post_date=date(2024, 10, 15), + open=Decimal("150.25"), + close=Decimal("152.50"), + high=Decimal("153.75"), + low=Decimal("149.80"), + adj_close=Decimal("152.35"), + volume=2500000 + ) + + ticker_dict = ticker.dict() + self.assertIsInstance(ticker_dict, dict) + self.assertEqual(ticker_dict["symbol"], "AAPL") + self.assertEqual(ticker_dict["volume"], 2500000) + + def test_ticker_json_serialization(self): + """Test JSON serialization.""" + ticker = self.BasicStockTicker( + symbol="AAPL", + post_date=date(2024, 10, 15), + open=Decimal("150.25"), + close=Decimal("152.50"), + high=Decimal("153.75"), + low=Decimal("149.80"), + adj_close=Decimal("152.35"), + volume=2500000 + ) + + json_str = ticker.json() + self.assertIsInstance(json_str, str) + self.assertIn("AAPL", json_str) + self.assertIn("2024-10-15", json_str) + + # Test parsing back + ticker_from_json = self.BasicStockTicker.parse_raw(json_str) + self.assertEqual(ticker_from_json.symbol, ticker.symbol) + self.assertEqual(ticker_from_json.post_date, ticker.post_date) + + +if __name__ == "__main__": + unittest.main() + From e139c8b836571e430a04219cc49f60e5b5d4640f Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Mon, 17 Nov 2025 15:15:47 -0500 Subject: [PATCH 07/20] converting to camelCase --- dbldatagen/spec/__init__.py | 30 +++++++++++++++----------- dbldatagen/spec/generator_spec.py | 9 ++++---- dbldatagen/spec/generator_spec_impl.py | 22 +++++++++---------- dbldatagen/spec/output_targets.py | 6 ++++-- 4 files changed, 37 insertions(+), 30 deletions(-) diff --git a/dbldatagen/spec/__init__.py b/dbldatagen/spec/__init__.py index dde2d8a7..ef6a5f2a 100644 --- a/dbldatagen/spec/__init__.py +++ b/dbldatagen/spec/__init__.py @@ -4,9 +4,12 @@ in a type-safe, declarative way. """ +from typing import Any + # Import only the compat layer by default to avoid triggering Spark/heavy dependencies from .compat import BaseModel, Field, constr, root_validator, validator + # Lazy imports for heavy modules - import these explicitly when needed # from .column_spec import ColumnSpec # from .generator_spec import GeneratorSpec @@ -14,26 +17,29 @@ __all__ = [ "BaseModel", + "ColumnDefinition", + "DatagenSpec", "Field", + "Generator", "constr", "root_validator", "validator", - "ColumnSpec", - "GeneratorSpec", - "GeneratorSpecImpl", ] -def __getattr__(name): - """Lazy import heavy modules to avoid triggering Spark initialization.""" +def __getattr__(name: str) -> Any: # noqa: ANN401 + """Lazy import heavy modules to avoid triggering Spark initialization. + + Note: Imports are intentionally inside this function to enable lazy loading + and avoid importing heavy dependencies (pandas, IPython, Spark) until needed. 
+ """ if name == "ColumnSpec": - from .column_spec import ColumnSpec - return ColumnSpec + from .column_spec import ColumnDefinition # noqa: PLC0415 + return ColumnDefinition elif name == "GeneratorSpec": - from .generator_spec import GeneratorSpec - return GeneratorSpec + from .generator_spec import DatagenSpec # noqa: PLC0415 + return DatagenSpec elif name == "GeneratorSpecImpl": - from .generator_spec_impl import GeneratorSpecImpl - return GeneratorSpecImpl + from .generator_spec_impl import Generator # noqa: PLC0415 + return Generator raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 67f460b7..2783ff7d 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -1,19 +1,18 @@ from __future__ import annotations import logging -from typing import Any, Literal, Union +from typing import Any, Union import pandas as pd from IPython.display import HTML, display from dbldatagen.spec.column_spec import ColumnDefinition -from .compat import BaseModel, validator -from .output_targets import UCSchemaTarget, FilePathTarget - -logger = logging.getLogger(__name__) +from .compat import BaseModel +from .output_targets import FilePathTarget, UCSchemaTarget +logger = logging.getLogger(__name__) class TableDefinition(BaseModel): diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index fc53863e..513e54d8 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -60,7 +60,7 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> self.app_name = app_name logger.info("Generator initialized with SparkSession") - def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> dict[str, Any]: + def _columnSpecToDatagenColumnSpec(self, col_def: ColumnDefinition) -> dict[str, Any]: """Convert a ColumnDefinition spec into dbldatagen DataGenerator column arguments. This internal method translates the declarative ColumnDefinition format into the @@ -124,7 +124,7 @@ def _columnspec_to_datagen_columnspec(self, col_def: ColumnDefinition) -> dict[s return kwargs - def _prepare_data_generators( + def _prepareDataGenerators( self, config: DatagenSpec, config_source_name: str = "PydanticConfig" @@ -151,7 +151,7 @@ def _prepare_data_generators( :raises ValueError: If table configuration is invalid (should be caught by validate() first) .. note:: - This is an internal method. Use generate_and_write_data() for the complete workflow + This is an internal method. Use generateAndWriteData() for the complete workflow .. note:: Preparation is separate from building to allow inspection and modification of @@ -188,7 +188,7 @@ def _prepare_data_generators( # Process each column for col_def in table_spec.columns: - kwargs = self._columnspec_to_datagen_columnspec(col_def) + kwargs = self._columnSpecToDatagenColumnSpec(col_def) data_gen = data_gen.withColumn(colName=col_def.name, **kwargs) # Has performance implications. @@ -203,7 +203,7 @@ def _prepare_data_generators( logger.info("All data generators prepared successfully") return prepared_generators - def write_prepared_data( + def writePreparedData( self, prepared_generators: dict[str, dg.DataGenerator], output_destination: Union[UCSchemaTarget, FilePathTarget, None], @@ -224,7 +224,7 @@ def write_prepared_data( 4. 
Logs row counts and write locations :param prepared_generators: Dictionary mapping table names to DataGenerator objects - (typically from _prepare_data_generators()) + (typically from _prepareDataGenerators()) :param output_destination: Target location for output. Can be UCSchemaTarget, FilePathTarget, or None (no write, data generated only) :param config_source_name: Descriptive name for the config source, used in logging @@ -275,7 +275,7 @@ def write_prepared_data( raise RuntimeError(f"Failed to write table '{table_name}': {e}") from e logger.info("All data writes completed successfully") - def generate_and_write_data( + def generateAndWriteData( self, config: DatagenSpec, config_source_name: str = "PydanticConfig" @@ -293,7 +293,7 @@ def generate_and_write_data( 5. Logs progress and completion status This method is the recommended entry point for most use cases. For more control over - the generation process, use _prepare_data_generators() and write_prepared_data() separately. + the generation process, use _prepareDataGenerators() and writePreparedData() separately. :param config: DatagenSpec object defining tables, columns, and output destination. Should be validated with config.validate() before calling this method @@ -317,13 +317,13 @@ def generate_and_write_data( ... ) >>> spec.validate() # Check for errors first >>> generator = Generator(spark) - >>> generator.generate_and_write_data(spec) + >>> generator.generateAndWriteData(spec) """ logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables") try: # Phase 1: Prepare data generators - prepared_generators_map = self._prepare_data_generators(config, config_source_name) + prepared_generators_map = self._prepareDataGenerators(config, config_source_name) if not prepared_generators_map and list(config.tables.keys()): logger.warning( @@ -331,7 +331,7 @@ def generate_and_write_data( return # Phase 2: Write data - self.write_prepared_data( + self.writePreparedData( prepared_generators_map, config.output_destination, config_source_name diff --git a/dbldatagen/spec/output_targets.py b/dbldatagen/spec/output_targets.py index f9f51194..b403304a 100644 --- a/dbldatagen/spec/output_targets.py +++ b/dbldatagen/spec/output_targets.py @@ -1,6 +1,8 @@ -from .compat import BaseModel, validator -from typing import Literal import logging +from typing import Literal + +from .compat import BaseModel, validator + logger = logging.getLogger(__name__) From 52a4283722925ebb5e37f725f890c5e141ce4db9 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Mon, 17 Nov 2025 15:32:04 -0500 Subject: [PATCH 08/20] validation into a diff module --- dbldatagen/spec/generator_spec.py | 66 +------------------------------ dbldatagen/spec/validation.py | 64 ++++++++++++++++++++++++++++++ tests/test_specs.py | 3 +- 3 files changed, 67 insertions(+), 66 deletions(-) create mode 100644 dbldatagen/spec/validation.py diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 2783ff7d..735c5b3e 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -10,6 +10,7 @@ from .compat import BaseModel from .output_targets import FilePathTarget, UCSchemaTarget +from .validation import ValidationResult logger = logging.getLogger(__name__) @@ -41,71 +42,6 @@ class TableDefinition(BaseModel): columns: list[ColumnDefinition] -class ValidationResult: - """Container for validation results that collects errors and warnings during spec validation. 
- - This class accumulates validation issues found while checking a DatagenSpec configuration. - It distinguishes between errors (which prevent data generation) and warnings (which - indicate potential issues but don't block generation). - - .. note:: - Validation passes if there are no errors, even if warnings are present - """ - - def __init__(self) -> None: - """Initialize an empty ValidationResult with no errors or warnings.""" - self.errors: list[str] = [] - self.warnings: list[str] = [] - - def add_error(self, message: str) -> None: - """Add an error message to the validation results. - - Errors indicate critical issues that will prevent successful data generation. - - :param message: Descriptive error message explaining the validation failure - """ - self.errors.append(message) - - def add_warning(self, message: str) -> None: - """Add a warning message to the validation results. - - Warnings indicate potential issues or non-optimal configurations that may affect - data generation but won't prevent it from completing. - - :param message: Descriptive warning message explaining the potential issue - """ - self.warnings.append(message) - - def is_valid(self) -> bool: - """Check if validation passed without errors. - - :returns: True if there are no errors (warnings are allowed), False otherwise - """ - return len(self.errors) == 0 - - def __str__(self) -> str: - """Generate a formatted string representation of all validation results. - - :returns: Multi-line string containing formatted errors and warnings with counts - """ - lines = [] - if self.is_valid(): - lines.append("✓ Validation passed successfully") - else: - lines.append("✗ Validation failed") - - if self.errors: - lines.append(f"\nErrors ({len(self.errors)}):") - for i, error in enumerate(self.errors, 1): - lines.append(f" {i}. {error}") - - if self.warnings: - lines.append(f"\nWarnings ({len(self.warnings)}):") - for i, warning in enumerate(self.warnings, 1): - lines.append(f" {i}. {warning}") - - return "\n".join(lines) - class DatagenSpec(BaseModel): """Top-level specification for synthetic data generation across one or more tables. diff --git a/dbldatagen/spec/validation.py b/dbldatagen/spec/validation.py new file mode 100644 index 00000000..ed383ffe --- /dev/null +++ b/dbldatagen/spec/validation.py @@ -0,0 +1,64 @@ +class ValidationResult: + """Container for validation results that collects errors and warnings during spec validation. + + This class accumulates validation issues found while checking a DatagenSpec configuration. + It distinguishes between errors (which prevent data generation) and warnings (which + indicate potential issues but don't block generation). + + .. note:: + Validation passes if there are no errors, even if warnings are present + """ + + def __init__(self) -> None: + """Initialize an empty ValidationResult with no errors or warnings.""" + self.errors: list[str] = [] + self.warnings: list[str] = [] + + def add_error(self, message: str) -> None: + """Add an error message to the validation results. + + Errors indicate critical issues that will prevent successful data generation. + + :param message: Descriptive error message explaining the validation failure + """ + self.errors.append(message) + + def add_warning(self, message: str) -> None: + """Add a warning message to the validation results. + + Warnings indicate potential issues or non-optimal configurations that may affect + data generation but won't prevent it from completing. 
+ + :param message: Descriptive warning message explaining the potential issue + """ + self.warnings.append(message) + + def is_valid(self) -> bool: + """Check if validation passed without errors. + + :returns: True if there are no errors (warnings are allowed), False otherwise + """ + return len(self.errors) == 0 + + def __str__(self) -> str: + """Generate a formatted string representation of all validation results. + + :returns: Multi-line string containing formatted errors and warnings with counts + """ + lines = [] + if self.is_valid(): + lines.append("✓ Validation passed successfully") + else: + lines.append("✗ Validation failed") + + if self.errors: + lines.append(f"\nErrors ({len(self.errors)}):") + for i, error in enumerate(self.errors, 1): + lines.append(f" {i}. {error}") + + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + for i, warning in enumerate(self.warnings, 1): + lines.append(f" {i}. {warning}") + + return "\n".join(lines) diff --git a/tests/test_specs.py b/tests/test_specs.py index d3c8ab2c..87f40c58 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -6,8 +6,9 @@ ColumnDefinition, UCSchemaTarget, FilePathTarget, - ValidationResult ) +from dbldatagen.spec.validation import ValidationResult + class TestValidationResult: """Tests for ValidationResult class""" From a0ce13ba411ec6529d2ec0692bfd9eaaea8af7f7 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Tue, 2 Dec 2025 09:46:06 -0500 Subject: [PATCH 09/20] removing compat/scratch notes --- pydantic_compat.md | 101 --------------------------------------------- scratch.md | 4 -- 2 files changed, 105 deletions(-) delete mode 100644 pydantic_compat.md delete mode 100644 scratch.md diff --git a/pydantic_compat.md b/pydantic_compat.md deleted file mode 100644 index abf26e60..00000000 --- a/pydantic_compat.md +++ /dev/null @@ -1,101 +0,0 @@ -To write code that works on both Pydantic V1 and V2 and ensures a smooth future migration, you should code against the V1 API but import it through a compatibility shim. This approach uses V1's syntax, which Pydantic V2 can understand via its built-in V1 compatibility layer. - ------ - -### \#\# The Golden Rule: Code to V1, Import via a Shim 💡 - -The core strategy is to **write all your models using Pydantic V1 syntax and features**. You then use a special utility file to handle the imports, which makes your application code completely agnostic to the installed Pydantic version. - ------ - -### \#\# 1. Implement a Compatibility Shim (`compat.py`) - -This is the most critical step. Create a file named `compat.py` in your project that intelligently imports Pydantic components. Your application will import everything from this file instead of directly from `pydantic`. - -```python -# compat.py -# This module acts as a compatibility layer for Pydantic V1 and V2. - -try: - # This will succeed on environments with Pydantic V2.x - # It imports the V1 API that is bundled within V2. - from pydantic.v1 import BaseModel, Field, validator, constr - -except ImportError: - # This will be executed on environments with only Pydantic V1.x - from pydantic import BaseModel, Field, validator, constr - -# In your application code, do this: -# from .compat import BaseModel -# NOT this: -# from pydantic import BaseModel -``` - ------ - -### \#\# 2. Stick to V1 Features and Syntax (Do's and Don'ts) - -By following these rules in your application code, you ensure the logic works on both versions. 
- -#### **✅ Models and Fields: DO** - - * Use standard `BaseModel` and `Field` for all your data structures. This is the most stable part of the API. - -#### **❌ Models and Fields: DON'T** - - * **Do not use `__root__` models**. This V1 feature was removed in V2 and the compatibility is not perfect. Instead, model the data explicitly, even if it feels redundant. - * **Bad (Avoid):** `class MyList(BaseModel): __root__: list[str]` - * **Good (Compatible):** `class MyList(BaseModel): items: list[str]` - -#### **✅ Configuration: DO** - - * Use the nested `class Config:` for model configuration. This is the V1 way and is fully supported by the V2 compatibility layer. - * **Example:** - ```python - from .compat import BaseModel - - class User(BaseModel): - id: int - full_name: str - - class Config: - orm_mode = True # V2's compatibility layer translates this - allow_population_by_field_name = True - ``` - -#### **❌ Configuration: DON'T** - - * **Do not use the V2 `model_config` dictionary**. This is a V2-only feature. - -#### **✅ Validators and Data Types: DO** - - * Use the standard V1 `@validator`. It's robust and works perfectly across both versions. - * Use V1 constrained types like `constr`, `conint`, `conlist`. - * **Example:** - ```python - from .compat import BaseModel, validator, constr - - class Product(BaseModel): - name: constr(min_length=3) - - @validator("name") - def name_must_be_alpha(cls, v): - if not v.isalpha(): - raise ValueError("Name must be alphabetic") - return v - ``` - -#### **❌ Validators and Data Types: DON'T** - - * **Do not use V2 decorators** like `@field_validator`, `@model_validator`, or `@field_serializer`. - * **Do not use the V2 `Annotated` syntax** for validation (e.g., `Annotated[str, StringConstraints(min_length=2)]`). - ------ - -### \#\# 3. The Easy Migration Path - -When you're finally ready to leave V1 behind and upgrade your code to be V2-native, the process will be straightforward because your code is already consistent: - -1. **Change Imports**: Your first step will be a simple find-and-replace to change all `from .compat import ...` statements to `from pydantic import ...`. -2. **Run a Codelinter**: Tools like **Ruff** have built-in rules that can automatically refactor most of your V1 syntax (like `Config` classes and `@validator`s) to the new V2 syntax. -3. **Manual Refinements**: Address any complex patterns the automated tools couldn't handle, like replacing your `__root__` model alternatives. 
\ No newline at end of file diff --git a/scratch.md b/scratch.md deleted file mode 100644 index a3afa5c3..00000000 --- a/scratch.md +++ /dev/null @@ -1,4 +0,0 @@ -Pydantic Notes -https://docs.databricks.com/aws/en/release-notes/runtime/14.3lts - 1.10.6 -https://docs.databricks.com/aws/en/release-notes/runtime/15.4lts - 1.10.6 -https://docs.databricks.com/aws/en/release-notes/runtime/16.4lts - 2.8.2 (2.20.1 - core) \ No newline at end of file From f6c9a694ce7b2c9f98ef7ed2eb854767f69651e3 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Tue, 2 Dec 2025 13:17:58 -0500 Subject: [PATCH 10/20] marking the spec module experimental --- CHANGELOG.md | 33 +++++----- README.md | 91 +++++++++++++------------- dbldatagen/spec/__init__.py | 6 +- dbldatagen/spec/column_spec.py | 3 + dbldatagen/spec/generator_spec.py | 6 ++ dbldatagen/spec/generator_spec_impl.py | 3 + 6 files changed, 80 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc8868b7..50fd539f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to the Databricks Labs Data Generator will be documented in ### unreleased -#### Fixed +#### Fixed * Updated build scripts to use Ubuntu 22.04 to correspond to environment in Databricks runtime * Refactored `DataAnalyzer` and `BasicStockTickerProvider` to comply with ANSI SQL standards * Removed internal modification of `SparkSession` @@ -23,6 +23,7 @@ All notable changes to the Databricks Labs Data Generator will be documented in #### Added * Added support for serialization to/from JSON format * Added Ruff and mypy tooling +* Pydantic-based specification API (Experimental) ### Version 0.4.0 Hotfix 2 @@ -59,7 +60,7 @@ All notable changes to the Databricks Labs Data Generator will be documented in * Updated docs for complex data types / JSON to correct code examples * Updated license file in public docs -#### Fixed +#### Fixed * Fixed scenario where `DataAnalyzer` is used on dataframe containing a column named `summary` ### Version 0.3.6 @@ -90,14 +91,14 @@ All notable changes to the Databricks Labs Data Generator will be documented in ### Version 0.3.4 Post 2 ### Fixed -* Fix for use of values in columns of type array, map and struct +* Fix for use of values in columns of type array, map and struct * Fix for generation of arrays via `numFeatures` and `structType` attributes when numFeatures has value of 1 ### Version 0.3.4 Post 1 ### Fixed -* Fix for use and configuration of root logger +* Fix for use and configuration of root logger ### Acknowledgements Thanks to Marvin Schenkel for the contribution @@ -120,7 +121,7 @@ Thanks to Marvin Schenkel for the contribution #### Changed * Fixed use of logger in _version.py and in spark_singleton.py -* Fixed template issues +* Fixed template issues * Document reformatting and updates, related code comment changes ### Fixed @@ -133,19 +134,19 @@ Thanks to Marvin Schenkel for the contribution ### Version 0.3.2 #### Changed -* Adjusted column build phase separation (i.e which select statement is used to build columns) so that a +* Adjusted column build phase separation (i.e which select statement is used to build columns) so that a column with a SQL expression can refer to previously created columns without use of a `baseColumn` attribute * Changed build labelling to comply with PEP440 -#### Fixed +#### Fixed * Fixed compatibility of build with older versions of runtime that rely on `pyparsing` version 2.4.7 -#### Added +#### Added * Parsing of SQL expressions to determine column dependencies #### Notes * 
The enhancements to build ordering does not change actual order of column building - - but adjusts which phase columns are built in + but adjusts which phase columns are built in ### Version 0.3.1 @@ -154,11 +155,11 @@ Thanks to Marvin Schenkel for the contribution * Refactoring of template text generation for better performance via vectorized implementation * Additional migration of tests to use of `pytest` -#### Fixed +#### Fixed * added type parsing support for binary and constructs such as `nvarchar(10)` -* Fixed error occurring when schema contains map, array or struct. +* Fixed error occurring when schema contains map, array or struct. -#### Added +#### Added * Ability to change name of seed column to custom name (defaults to `id`) * Added type parsing support for structs, maps and arrays and combinations of the above @@ -207,14 +208,14 @@ See the contents of the file `python/require.txt` to see the Python package depe The code for the Databricks Data Generator has the following dependencies * Requires Databricks runtime 9.1 LTS or later -* Requires Spark 3.1.2 or later +* Requires Spark 3.1.2 or later * Requires Python 3.8.10 or later -While the data generator framework does not require all libraries used by the runtimes, where a library from +While the data generator framework does not require all libraries used by the runtimes, where a library from the Databricks runtime is used, it will use the version found in the Databricks runtime for 9.1 LTS or later. You can use older versions of the Databricks Labs Data Generator by referring to that explicit version. -The recommended method to install the package is to use `pip install` in your notebook to install the package from +The recommended method to install the package is to use `pip install` in your notebook to install the package from PyPi For example: @@ -227,7 +228,7 @@ To use an older DB runtime version in your notebook, you can use the following c %pip install git+https://github.com/databrickslabs/dbldatagen@dbr_7_3_LTS_compat ``` -See the [Databricks runtime release notes](https://docs.databricks.com/release-notes/runtime/releases.html) +See the [Databricks runtime release notes](https://docs.databricks.com/release-notes/runtime/releases.html) for the full list of dependencies used by the Databricks runtime. This can be found at : https://docs.databricks.com/release-notes/runtime/releases.html diff --git a/README.md b/README.md index 77d3dec0..6387f8a1 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ -# Databricks Labs Data Generator (`dbldatagen`) +# Databricks Labs Data Generator (`dbldatagen`) [Documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) | [Release Notes](CHANGELOG.md) | [Examples](examples) | -[Tutorial](tutorial) +[Tutorial](tutorial) [![build](https://github.com/databrickslabs/dbldatagen/workflows/build/badge.svg?branch=master)](https://github.com/databrickslabs/dbldatagen/actions?query=workflow%3Abuild+branch%3Amaster) @@ -14,38 +14,38 @@ [![PyPi downloads](https://img.shields.io/pypi/dm/dbldatagen?label=PyPi%20Downloads)](https://pypistats.org/packages/dbldatagen) [![lines of code](https://tokei.rs/b1/github/databrickslabs/dbldatagen)]([https://codecov.io/github/databrickslabs/dbldatagen](https://github.com/databrickslabs/dbldatagen)) - ## Project Description -The `dbldatagen` Databricks Labs project is a Python library for generating synthetic data within the Databricks -environment using Spark. 
The generated data may be used for testing, benchmarking, demos, and many +The `dbldatagen` Databricks Labs project is a Python library for generating synthetic data within the Databricks +environment using Spark. The generated data may be used for testing, benchmarking, demos, and many other uses. -It operates by defining a data generation specification in code that controls +It operates by defining a data generation specification in code that controls how the synthetic data is generated. The specification may incorporate the use of existing schemas or create data in an ad-hoc fashion. -It has no dependencies on any libraries that are not already installed in the Databricks +It has no dependencies on any libraries that are not already installed in the Databricks runtime, and you can use it from Scala, R or other languages by defining a view over the generated data. ### Feature Summary It supports: -* Generating synthetic data at scale up to billions of rows within minutes using appropriately sized clusters -* Generating repeatable, predictable data supporting the need for producing multiple tables, Change Data Capture, +* Generating synthetic data at scale up to billions of rows within minutes using appropriately sized clusters +* Generating repeatable, predictable data supporting the need for producing multiple tables, Change Data Capture, merge and join scenarios with consistency between primary and foreign keys -* Generating synthetic data for all of the -Spark SQL supported primitive types as a Spark data frame which may be persisted, -saved to external storage or +* Generating synthetic data for all of the +Spark SQL supported primitive types as a Spark data frame which may be persisted, +saved to external storage or used in other computations * Generating ranges of dates, timestamps, and numeric values * Generation of discrete values - both numeric and text -* Generation of values at random and based on the values of other fields +* Generation of values at random and based on the values of other fields (either based on the `hash` of the underlying values or the values themselves) -* Ability to specify a distribution for random data generation +* Ability to specify a distribution for random data generation * Generating arrays of values for ML-style feature arrays * Applying weights to the occurrence of values * Generating values to conform to a schema or independent of an existing schema @@ -53,14 +53,15 @@ used in other computations * plugin mechanism to allow use of 3rd party libraries such as Faker * Use within a Databricks Delta Live Tables pipeline as a synthetic data generation source * Generate synthetic data generation code from existing schema or data (experimental) +* Pydantic-based specification API for type-safe data generation (experimental) * Use of standard datasets for quick generation of synthetic data Details of these features can be found in the online documentation - - [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html). + [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html). ## Documentation -Please refer to the [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) for +Please refer to the [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) for details of use and many examples. 
Release notes and details of the latest changes for this specific release @@ -76,32 +77,32 @@ Within a Databricks notebook, invoke the following in a notebook cell %pip install dbldatagen ``` -The Pip install command can be invoked within a Databricks notebook, a Delta Live Tables pipeline +The Pip install command can be invoked within a Databricks notebook, a Delta Live Tables pipeline and even works on the Databricks community edition. -The documentation [installation notes](https://databrickslabs.github.io/dbldatagen/public_docs/installation_notes.html) +The documentation [installation notes](https://databrickslabs.github.io/dbldatagen/public_docs/installation_notes.html) contains details of installation using alternative mechanisms. -## Compatibility -The Databricks Labs Data Generator framework can be used with Pyspark 3.4.1 and Python 3.10.12 or later. These are +## Compatibility +The Databricks Labs Data Generator framework can be used with Pyspark 3.4.1 and Python 3.10.12 or later. These are compatible with the Databricks runtime 13.3 LTS and later releases. This version also provides Unity Catalog compatibily. -For full library compatibility for a specific Databricks Spark release, see the Databricks +For full library compatibility for a specific Databricks Spark release, see the Databricks release notes for library compatibility - https://docs.databricks.com/release-notes/runtime/releases.html -In older releases, when using the Databricks Labs Data Generator on "Unity Catalog" enabled Databricks environments, -the Data Generator requires the use of `Single User` or `No Isolation Shared` access modes when using Databricks -runtimes prior to release 13.2. This is because some needed features are not available in `Shared` -mode (for example, use of 3rd party libraries, use of Python UDFs) in these releases. +In older releases, when using the Databricks Labs Data Generator on "Unity Catalog" enabled Databricks environments, +the Data Generator requires the use of `Single User` or `No Isolation Shared` access modes when using Databricks +runtimes prior to release 13.2. This is because some needed features are not available in `Shared` +mode (for example, use of 3rd party libraries, use of Python UDFs) in these releases. Depending on settings, the `Custom` access mode may be supported for those releases. The use of Unity Catalog `Shared` access mode is supported in Databricks runtimes from Databricks runtime release 13.2 -onwards. +onwards. -*This version of the data generator uses the Databricks runtime 13.3 LTS as the minimum supported +*This version of the data generator uses the Databricks runtime 13.3 LTS as the minimum supported version and alleviates these issues.* See the following documentation for more information: @@ -109,7 +110,7 @@ See the following documentation for more information: - https://docs.databricks.com/data-governance/unity-catalog/compute.html ## Using the Data Generator -To use the data generator, install the library using the `%pip install` method or install the Python wheel directly +To use the data generator, install the library using the `%pip install` method or install the Python wheel directly in your environment. Once the library has been installed, you can use it to generate a data frame composed of synthetic data. @@ -120,7 +121,7 @@ for your use case. 
```buildoutcfg import dbldatagen as dg df = dg.Datasets(spark, "basic/user").get(rows=1000_000).build() -num_rows=df.count() +num_rows=df.count() ``` You can also define fully custom data sets using the `DataGenerator` class. @@ -135,48 +136,48 @@ data_rows = 1000 * 1000 df_spec = (dg.DataGenerator(spark, name="test_data_set1", rows=data_rows, partitions=4) .withIdOutput() - .withColumn("r", FloatType(), + .withColumn("r", FloatType(), expr="floor(rand() * 350) * (86400 + 3600)", numColumns=column_count) .withColumn("code1", IntegerType(), minValue=100, maxValue=200) .withColumn("code2", IntegerType(), minValue=0, maxValue=10) .withColumn("code3", StringType(), values=['a', 'b', 'c']) - .withColumn("code4", StringType(), values=['a', 'b', 'c'], + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) - .withColumn("code5", StringType(), values=['a', 'b', 'c'], + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) - + ) - + df = df_spec.build() -num_rows=df.count() +num_rows=df.count() ``` -Refer to the [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) for further -examples. +Refer to the [online documentation](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) for further +examples. The GitHub repository also contains further examples in the examples directory. ## Spark and Databricks Runtime Compatibility -The `dbldatagen` package is intended to be compatible with recent LTS versions of the Databricks runtime, including -older LTS versions at least from 13.3 LTS and later. It also aims to be compatible with Delta Live Table runtimes, -including `current` and `preview`. +The `dbldatagen` package is intended to be compatible with recent LTS versions of the Databricks runtime, including +older LTS versions at least from 13.3 LTS and later. It also aims to be compatible with Delta Live Table runtimes, +including `current` and `preview`. While we don't specifically drop support for older runtimes, changes in Pyspark APIs or APIs from dependent packages such as `numpy`, `pandas`, `pyarrow`, and `pyparsing` make cause issues with older -runtimes. +runtimes. -By design, installing `dbldatagen` does not install releases of dependent packages in order +By design, installing `dbldatagen` does not install releases of dependent packages in order to preserve the curated set of packages pre-installed in any Databricks runtime environment. When building on local environments, run `make dev` to install required dependencies. ## Project Support Please note that all projects released under [`Databricks Labs`](https://www.databricks.com/learn/labs) - are provided for your exploration only, and are not formally supported by Databricks with Service Level Agreements -(SLAs). They are provided AS-IS, and we do not make any guarantees of any kind. Please do not submit a support ticket + are provided for your exploration only, and are not formally supported by Databricks with Service Level Agreements +(SLAs). They are provided AS-IS, and we do not make any guarantees of any kind. Please do not submit a support ticket relating to any issues arising from the use of these projects. -Any issues discovered through the use of this project should be filed as issues on the GitHub Repo. +Any issues discovered through the use of this project should be filed as issues on the GitHub Repo. They will be reviewed as time permits, but there are no formal SLAs for support. 
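The experimental Pydantic-based specification API called out in the README feature list (and marked experimental in the `dbldatagen/spec` module below) can be exercised end to end. The following sketch is illustrative only, not taken verbatim from the patch: it assumes an active SparkSession bound to `spark`, and while the class and method names (`ColumnDefinition`, `TableDefinition`, `DatagenSpec`, `Generator.generateAndWriteData`) match those introduced in this patch series, the import paths may shift while the API remains experimental.

```python
# Minimal sketch of the experimental spec API; assumes an active SparkSession `spark`.
from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition
from dbldatagen.spec.generator_spec_impl import Generator

columns = [
    ColumnDefinition(name="customer_id", type="long",
                     options={"minValue": 1000000, "maxValue": 9999999999}),
    ColumnDefinition(name="email", type="string",
                     options={"template": r"\w.\w@\w.com"}),
]

spec = DatagenSpec(
    tables={"users": TableDefinition(number_of_rows=1000, partitions=2, columns=columns)},
    output_destination=None,
    generator_options={"randomSeedMethod": "hash_fieldname"},
)

# Collect errors and warnings before generating; validation passes if there are no errors.
result = spec.validate(strict=False)
if result.is_valid():
    Generator(spark).generateAndWriteData(spec)
```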
diff --git a/dbldatagen/spec/__init__.py b/dbldatagen/spec/__init__.py index ef6a5f2a..afede3f6 100644 --- a/dbldatagen/spec/__init__.py +++ b/dbldatagen/spec/__init__.py @@ -1,7 +1,11 @@ -"""Pydantic-based specification API for dbldatagen. +"""Pydantic-based specification API for dbldatagen (Experimental). This module provides Pydantic models and specifications for defining data generation in a type-safe, declarative way. + +.. warning:: + Experimental - This API is experimental and both APIs and generated code + are liable to change in future versions. """ from typing import Any diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py index 74e9e57f..c6fa20f8 100644 --- a/dbldatagen/spec/column_spec.py +++ b/dbldatagen/spec/column_spec.py @@ -54,6 +54,9 @@ class ColumnDefinition(BaseModel): "auto" (infer behavior), "hash" (hash the base column values), "values" (use base column values directly) + .. warning:: + Experimental - This API is subject to change in future versions + .. note:: Primary columns have special constraints: - Must have a type defined diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 735c5b3e..386178cd 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -30,6 +30,9 @@ class TableDefinition(BaseModel): :param columns: List of ColumnDefinition objects specifying the columns to generate in this table. At least one column must be specified + .. warning:: + Experimental - This API is subject to change in future versions + .. note:: Setting an appropriate number of partitions can significantly impact generation performance. As a rule of thumb, use 2-4 partitions per CPU core available in your Spark cluster @@ -64,6 +67,9 @@ class DatagenSpec(BaseModel): :param intended_for_databricks: Flag indicating if this spec is designed for Databricks. May be automatically inferred based on configuration + .. warning:: + Experimental - This API is subject to change in future versions + .. note:: Call the validate() method before using this spec to ensure configuration is correct diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index 513e54d8..fc56699b 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -36,6 +36,9 @@ class Generator: :param spark: Active SparkSession to use for data generation :param app_name: Application name used in logging and tracking. Defaults to "DataGen_ClassBased" + .. warning:: + Experimental - This API is subject to change in future versions + .. note:: The Generator requires an active SparkSession. On Databricks, you can use the pre-configured `spark` variable. 
For local development, create a SparkSession first From 71e64516dbb9c86808d662c8d1433153856c5809 Mon Sep 17 00:00:00 2001 From: Greg Hansen <163584195+ghanse@users.noreply.github.com> Date: Tue, 23 Sep 2025 10:29:01 -0400 Subject: [PATCH 11/20] Add methods for persisting generated data (#352) * added use of ABC to mark TextGenerator as abstract * Lint text generators module * Add persistence methods * Add tests and docs; Update PR template * Update hatch installation for push action * Refactor * Update method names and signatures --------- Co-authored-by: ronanstokes-db Co-authored-by: Ronan Stokes <42389040+ronanstokes-db@users.noreply.github.com> --- .github/workflows/push.yml | 3 +- PULL_REQUEST_TEMPLATE.md | 41 ++------ dbldatagen/__init__.py | 3 +- dbldatagen/config.py | 36 +++++++ dbldatagen/data_generator.py | 53 +++++++--- dbldatagen/datasets/basic_stock_ticker.py | 3 +- dbldatagen/datasets_object.py | 50 ++++----- dbldatagen/utils.py | 47 ++++++++- docs/source/index.rst | 2 + docs/source/using_standard_datasets.rst | 10 +- docs/source/writing_generated_data.rst | 118 ++++++++++++++++++++++ tests/test_output.py | 115 +++++++++++++++++++++ tests/test_streaming.py | 2 +- 13 files changed, 404 insertions(+), 79 deletions(-) create mode 100644 dbldatagen/config.py create mode 100644 docs/source/writing_generated_data.rst create mode 100644 tests/test_output.py diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index cc04952e..376da7f8 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -61,7 +61,8 @@ jobs: python-version: '3.10' - name: Install Hatch - run: pip install hatch + # click 8.3+ introduced bug for hatch + run: pip install "hatch==1.13.0" "click<8.3" - name: Run unit tests run: make dev test diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md index 50f651b4..b14a6e9b 100644 --- a/PULL_REQUEST_TEMPLATE.md +++ b/PULL_REQUEST_TEMPLATE.md @@ -1,34 +1,15 @@ -## Proposed changes +## Changes + -Describe the big picture of your changes here to communicate to the maintainers. -If it fixes a bug or resolves a feature request, please provide a link to that issue. +### Linked issues + -## Types of changes +Resolves #.. -What types of changes does your code introduce to dbldatagen? -_Put an `x` in the boxes that apply_ +### Requirements + -- [ ] Bug fix (non-breaking change which fixes an issue) -- [ ] New feature (non-breaking change which adds functionality) -- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) -- [ ] Change to tutorials, tests or examples -- [ ] Non code change (readme, images or other non-code assets) -- [ ] Documentation Update (if none of the other choices apply) - -## Checklist - -_Put an `x` in the boxes that apply. You can also fill these out after creating the PR. -If you're unsure about any of them, don't hesitate to ask. We're here to help! 
-This is simply a reminder of what we are going to look for before merging your code._ - -- [ ] Lint and unit tests pass locally with my changes -- [ ] I have added tests that prove my fix is effective or that my feature works -- [ ] I have added necessary documentation (if appropriate) -- [ ] Any dependent changes have been merged and published in downstream modules -- [ ] Submission does not reduce code coverage numbers -- [ ] Submission does not increase alerts or messages from prospector / lint - -## Further comments - -If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you -did and what alternatives you considered, etc... \ No newline at end of file +- [ ] manually tested +- [ ] updated documentation +- [ ] updated demos +- [ ] updated tests diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 835a59fb..3a00ce71 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -47,11 +47,12 @@ from .text_generator_plugins import PyfuncText, PyfuncTextFactory, FakerTextFactory, fakerText from .html_utils import HtmlUtils from .datasets_object import Datasets +from .config import OutputDataset __all__ = ["data_generator", "data_analyzer", "schema_parser", "daterange", "nrange", "column_generation_spec", "utils", "function_builder", "spark_singleton", "text_generators", "datarange", "datagen_constants", - "text_generator_plugins", "html_utils", "datasets_object", "constraints" + "text_generator_plugins", "html_utils", "datasets_object", "constraints", "config" ] diff --git a/dbldatagen/config.py b/dbldatagen/config.py new file mode 100644 index 00000000..a5fa7ecb --- /dev/null +++ b/dbldatagen/config.py @@ -0,0 +1,36 @@ +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module implements configuration classes for writing generated data. +""" +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class OutputDataset: + """ + This class implements an output sink configuration used to write generated data. An output location must be + provided. The output mode, format, and options can be provided. + + :param location: Output location for writing data. This could be an absolute path, a relative path to a Databricks + Volume, or a full table location using Unity catalog's 3-level namespace. + :param output_mode: Output mode for writing data (default is ``"append"``). + :param format: Output data format (default is ``"delta"``). + :param options: Optional dictionary of options for writing data (e.g. 
``{"mergeSchema": "true"}``) + """ + location: str + output_mode: str = "append" + format: str = "delta" + options: dict[str, str] | None = None + trigger: dict[str, str] | None = None + + def __post_init__(self) -> None: + if not self.trigger: + return + + # Only processingTime is currently supported + if "processingTime" not in self.trigger: + valid_trigger_format = '{"processingTime": "10 SECONDS"}' + raise ValueError(f"Attribute 'trigger' must be a dictionary of the form '{valid_trigger_format}'") diff --git a/dbldatagen/data_generator.py b/dbldatagen/data_generator.py index f609a000..a08c5537 100644 --- a/dbldatagen/data_generator.py +++ b/dbldatagen/data_generator.py @@ -15,11 +15,13 @@ from typing import Any from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.streaming.query import StreamingQuery from pyspark.sql.types import DataType, IntegerType, LongType, StringType, StructField, StructType from dbldatagen import datagen_constants from dbldatagen._version import _get_spark_version from dbldatagen.column_generation_spec import ColumnGenerationSpec +from dbldatagen.config import OutputDataset from dbldatagen.constraints import Constraint, SqlExpr from dbldatagen.datarange import DataRange from dbldatagen.distributions import DataDistribution @@ -28,7 +30,14 @@ from dbldatagen.serialization import SerializableToDict from dbldatagen.spark_singleton import SparkSingleton from dbldatagen.text_generators import TextGenerator -from dbldatagen.utils import DataGenError, deprecated, ensure, split_list_matching_condition, topologicalSort +from dbldatagen.utils import ( + DataGenError, + deprecated, + ensure, + split_list_matching_condition, + topologicalSort, + write_data_to_output, +) _OLD_MIN_OPTION: str = "min" @@ -1204,9 +1213,9 @@ def _generateColumnDefinition( ) -> ColumnGenerationSpec: """ generate field definition and column spec - .. note:: Any time that a new column definition is added, - we'll mark that the build plan needs to be regenerated. - For our purposes, the build plan determines the order of column generation etc. + .. note:: + Any time that a new column definition is added, we'll mark that the build plan needs to be regenerated. + For our purposes, the build plan determines the order of column generation etc. :returns: Newly added column_spec """ @@ -1381,7 +1390,6 @@ def _adjustBuildOrderForSqlDependencies(self, buildOrder: list[list[str]], colum :param buildOrder: list of lists of ids - each sublist represents phase of build :param columnSpecsByName: dictionary to map column names to column specs :returns: Spark SQL dataframe of generated test data - """ new_build_order = [] @@ -1476,8 +1484,8 @@ def withConstraint(self, constraint: Constraint) -> "DataGenerator": :returns: A modified version of the current DataGenerator with the constraint applied .. note:: - Constraints are applied at the end of the data generation. Depending on the type of the constraint, the - constraint may also affect other aspects of the data generation. + Constraints are applied at the end of the data generation. Depending on the type of the constraint, the + constraint may also affect other aspects of the data generation. """ assert constraint is not None, "Constraint cannot be empty" assert isinstance(constraint, Constraint), \ @@ -1494,8 +1502,8 @@ def withConstraints(self, constraints: list[Constraint]) -> "DataGenerator": :returns: A modified version of the current `DataGenerator` with the constraints applied .. 
note:: - Constraints are applied at the end of the data generation. Depending on the type of the constraint, the - constraint may also affect other aspects of the data generation. + Constraints are applied at the end of the data generation. Depending on the type of the constraint, the + constraint may also affect other aspects of the data generation. """ assert constraints is not None, "Constraints list cannot be empty" @@ -1515,9 +1523,9 @@ def withSqlConstraint(self, sqlExpression: str) -> "DataGenerator": :returns: A modified version of the current `DataGenerator` with the SQL expression constraint applied .. note:: - Note in the current implementation, this may be equivalent to adding where clauses to the generated dataframe - but in future releases, this may be optimized to affect the underlying data generation so that constraints - are satisfied more efficiently. + Note in the current implementation, this may be equivalent to adding where clauses to the generated dataframe + but in future releases, this may be optimized to affect the underlying data generation so that constraints + are satisfied more efficiently. """ self.withConstraint(SqlExpr(sqlExpression)) return self @@ -1909,6 +1917,27 @@ def scriptMerge( return result + def saveAsDataset( + self, + dataset: OutputDataset, + with_streaming: bool | None = None, + generator_options: dict[str, Any] | None = None + ) -> StreamingQuery | None: + """ + Builds a `DataFrame` from the `DataGenerator` and writes the data to an output dataset (e.g. a table or files). + + :param dataset: Output dataset for writing generated data + :param with_streaming: Whether to generate data using streaming. If None, auto-detects based on trigger + :param generator_options: Options for building the generator (e.g. `{"rowsPerSecond": 100}`) + :returns: A Spark `StreamingQuery` if data is written in streaming, otherwise `None` + """ + # Auto-detect streaming mode if not explicitly specified + if with_streaming is None: + with_streaming = dataset.trigger is not None and len(dataset.trigger) > 0 + + df = self.build(withStreaming=with_streaming, options=generator_options) + return write_data_to_output(df, output_dataset=dataset) + @staticmethod def loadFromJson(options: str) -> "DataGenerator": """ diff --git a/dbldatagen/datasets/basic_stock_ticker.py b/dbldatagen/datasets/basic_stock_ticker.py index 2d5ade39..74cfe3f9 100644 --- a/dbldatagen/datasets/basic_stock_ticker.py +++ b/dbldatagen/datasets/basic_stock_ticker.py @@ -15,7 +15,7 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider): """ Basic Stock Ticker Dataset - ======================== + ========================== This is a basic stock ticker dataset with time-series `symbol`, `open`, `close`, `high`, `low`, `adj_close`, and `volume` values. @@ -31,7 +31,6 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase Note that this dataset does not use any features that would prevent it from being used as a source for a streaming dataframe, and so the flag `supportsStreaming` is set to True. 
- """ DEFAULT_NUM_SYMBOLS = 100 DEFAULT_START_DATE = "2024-10-01" diff --git a/dbldatagen/datasets_object.py b/dbldatagen/datasets_object.py index fdae8f60..602e68ff 100644 --- a/dbldatagen/datasets_object.py +++ b/dbldatagen/datasets_object.py @@ -20,8 +20,9 @@ import re from dbldatagen.datasets.dataset_provider import DatasetProvider -from .spark_singleton import SparkSingleton -from .utils import strip_margins +from dbldatagen.spark_singleton import SparkSingleton +from dbldatagen.utils import strip_margins +from dbldatagen.data_generator import DataGenerator class Datasets: @@ -211,12 +212,6 @@ def get(self, table=None, rows=-1, partitions=-1, **kwargs): If the dataset supports multiple tables, the table may be specified in the `table` parameter. If none is specified, the primary table is used. - :param table: name of table to retrieve - :param rows: number of rows to generate. if -1, provider should compute defaults. - :param partitions: number of partitions to use.If -1, the number of partitions is computed automatically - table size and partitioning.If applied to a dataset with only a single table, this is ignored. - :param kwargs: additional keyword arguments to pass to the provider - If `rows` or `partitions` are not specified, default values are supplied by the provider. For multi-table datasets, the table name must be specified. For single table datasets, the table name may @@ -225,41 +220,44 @@ def get(self, table=None, rows=-1, partitions=-1, **kwargs): Additionally, for multi-table datasets, the table name must be one of the tables supported by the provider. Default number of rows for multi-table datasets may differ - for example a 'customers' table may have a 100,000 rows while a 'sales' table may have 1,000,000 rows. + + :param table: name of table to retrieve + :param rows: number of rows to generate. if -1, provider should compute defaults. + :param partitions: number of partitions to use.If -1, the number of partitions is computed automatically + table size and partitioning.If applied to a dataset with only a single table, this is ignored. + :param kwargs: additional keyword arguments to pass to the provider + :returns: table generator """ return self._get(providerName=self._name, tableName=table, rows=rows, partitions=partitions, **kwargs) def _getSupportingTable(self, *, providerName, tableName, rows=-1, partitions=-1, **kwargs): - providerInstance, providerDefinition = \ - self._getProviderInstanceAndMetadata(providerName, supportsStreaming=self._streamingRequired) + providerInstance, providerDefinition = self._getProviderInstanceAndMetadata( + providerName, supportsStreaming=self._streamingRequired + ) assert tableName is not None and len(tableName.strip()) > 0, "Data set name must be provided" if tableName not in providerDefinition.associatedDatasets: raise ValueError(f"Dataset `{tableName}` not a recognized dataset option") - dfSupportingTable = providerInstance.getAssociatedDataset(self._sparkSession, tableName=tableName, rows=rows, - partitions=partitions, - **kwargs) + dfSupportingTable = providerInstance.getAssociatedDataset( + self._sparkSession, tableName=tableName, rows=rows, partitions=partitions, **kwargs + ) return dfSupportingTable - def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs): - """Get a table generator from the dataset provider + def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs) -> DataGenerator: + """ + Gets a table generator from the dataset provider. 
- These are DataGenerator instances that can be used to generate the data. + Associated datasets are DataGenerator instances that can be used to generate the data. The dataset providers also optionally can provide supporting tables which are computed tables based on parameters. These are retrieved using the `getAssociatedDataset` method If the dataset supports multiple tables, the table may be specified in the `table` parameter. If none is specified, the primary table is used. - :param table: name of table to retrieve - :param rows: number of rows to generate. if -1, provider should compute defaults. - :param partitions: number of partitions to use.If -1, the number of partitions is computed automatically - table size and partitioning.If applied to a dataset with only a single table, this is ignored. - :param kwargs: additional keyword arguments to pass to the provider - If `rows` or `partitions` are not specified, default values are supplied by the provider. For multi-table datasets, the table name must be specified. For single table datasets, the table name may @@ -269,9 +267,13 @@ def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs): Default number of rows for multi-table datasets may differ - for example a 'customers' table may have a 100,000 rows while a 'sales' table may have 1,000,000 rows. - .. note :: + :param table: Name of table to retrieve + :param rows: Number of rows to generate. if -1, provider should compute defaults + :param partitions: number of partitions to use. If -1, the number of partitions is computed automatically table + size and partitioning. If applied to a dataset with only a single table, this is ignored. - This method may also be invoked via the aliased names - `getSupportingDataset` and `getCombinedDataset` + .. note :: + This method may also be invoked via the aliased names - `getSupportingDataset` and `getCombinedDataset` """ return self._getSupportingTable(providerName=self._name, tableName=table, rows=rows, partitions=partitions, **kwargs) diff --git a/dbldatagen/utils.py b/dbldatagen/utils.py index d9b7c661..e7acaa2b 100644 --- a/dbldatagen/utils.py +++ b/dbldatagen/utils.py @@ -18,6 +18,10 @@ from typing import Any import jmespath +from pyspark.sql import DataFrame +from pyspark.sql.streaming.query import StreamingQuery + +from dbldatagen.config import OutputDataset def deprecated(message: str = "") -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -293,10 +297,10 @@ def split_list_matching_condition(lst: list[Any], cond: Callable[[Any], bool]) - Result: `[['id'], ['city_name'], ['id'], ['city_id', 'city_pop'], - ['id'], ['city_id', 'city_pop', 'city_id', 'city_pop'], ['id']]` + ['id'], ['city_id', 'city_pop', 'city_id', 'city_pop'], ['id']]` - :arg lst: list of items to perform condition matches against - :arg cond: lambda function or function taking single argument and returning True or False + :param lst: list of items to perform condition matches against + :param cond: lambda function or function taking single argument and returning True or False :returns: list of sublists """ retval: list[list[Any]] = [] @@ -360,3 +364,40 @@ def system_time_millis() -> int: """ curr_time: int = round(time.time() / 1000) return curr_time + + +def write_data_to_output(df: DataFrame, output_dataset: OutputDataset) -> StreamingQuery | None: + """ + Writes a DataFrame to the sink configured in the output configuration. 
+ + :param df: Spark DataFrame to write + :param output_dataset: Output dataset configuration passed as an `OutputDataset` + :returns: A Spark `StreamingQuery` if data is written in streaming, otherwise `None` + """ + if df.isStreaming: + if not output_dataset.trigger: + query = ( + df.writeStream.format(output_dataset.format) + .outputMode(output_dataset.output_mode) + .options(**output_dataset.options) + .start(output_dataset.location) + ) + else: + query = ( + df.writeStream.format(output_dataset.format) + .outputMode(output_dataset.output_mode) + .options(**output_dataset.options) + .trigger(**output_dataset.trigger) + .start(output_dataset.location) + ) + return query + + else: + ( + df.write.format(output_dataset.format) + .mode(output_dataset.output_mode) + .options(**output_dataset.options) + .save(output_dataset.location) + ) + + return None diff --git a/docs/source/index.rst b/docs/source/index.rst index 7e5eb4be..54d5389e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,6 +37,8 @@ As it is installable via `%pip install`, it can also be incorporated in environm Using multiple tables Extending text generation Use with Delta Live Tables + Writing Generated Data + Data Generation Specs from Configuration Troubleshooting data generation .. toctree:: diff --git a/docs/source/using_standard_datasets.rst b/docs/source/using_standard_datasets.rst index f5264f7b..8705f9a1 100644 --- a/docs/source/using_standard_datasets.rst +++ b/docs/source/using_standard_datasets.rst @@ -34,11 +34,11 @@ provide summary derived data from the data produced by the data generation insta In keeping with the varied roles of the associated datasets, the associated datasets may be retrieved using one of several methods to get the dataframes in a way that is self descriptive. The methods are: - - `Datasets.getAssociatedDataset()` - returns a dataframe based on the supplied parameters and provider logic - - `Datasets.getSupportingDataset()` - alias for `Datasets.getAssociatedDataset()` - - `Datasets.getCombinedDataset()` - alias for `Datasets.getAssociatedDataset()` - - `Datasets.getEnrichedDataset()` - alias for `Datasets.getAssociatedDataset()` - - `Datasets.getSummaryDataset()` - alias for `Datasets.getAssociatedDataset()` + - `Datasets.getAssociatedDataset()` - returns a dataframe based on the supplied parameters and provider logic + - `Datasets.getSupportingDataset()` - alias for `Datasets.getAssociatedDataset()` + - `Datasets.getCombinedDataset()` - alias for `Datasets.getAssociatedDataset()` + - `Datasets.getEnrichedDataset()` - alias for `Datasets.getAssociatedDataset()` + - `Datasets.getSummaryDataset()` - alias for `Datasets.getAssociatedDataset()` The method names are intended to be self descriptive and to provide a clear indication of the role of the associated usage, but they are all aliases of `getAssociatedDataset()` and can be used interchangeably. diff --git a/docs/source/writing_generated_data.rst b/docs/source/writing_generated_data.rst new file mode 100644 index 00000000..94e7b7fb --- /dev/null +++ b/docs/source/writing_generated_data.rst @@ -0,0 +1,118 @@ +.. Databricks Labs Data Generator documentation master file, created by + sphinx-quickstart on Sun Jun 21 10:54:30 2020. + +Writing Generated Data to Tables or Files +=========================================================== + +Generated data can be written directly to output tables or files using the ``OutputDataset`` class. 
+ +Writing generated data to a table +--------------------------------- + +Once you've defined a ``DataGenerator``, call the ``saveAsDataset`` method to write data to a target table. + +.. code-block:: python + + import dbldatagen as dg + from dbldatagen.config import OutputDataset + + # Create a sample data generator with a few columns: + testDataSpec = ( + dg.DataGenerator(spark, name="users_dataset", rows=1000) + .withColumn("user_name", expr="concat('user_', id)") + .withColumn("email_address", expr="concat(user_name, '@email.com')") + .withColumn("phone_number", template="555-DDD-DDDD") + ) + + # Define an output configuration: + outputDataset = OutputDataset("main.demo.users") + + # Generate and write the output data: + testDataSpec.saveAsDataset(dataset=outputDataset) + +Writing generated data with streaming +------------------------------------- + +Specify a ``trigger`` to write output data using Structured Streaming. Triggers can be passed as +Python dictionaries (e.g. ``{"processingTime": "10 seconds"}`` to write data every 10 seconds). + +.. code-block:: python + + import dbldatagen as dg + from dbldatagen.config import OutputDataset + + # Create a sample data generator with a few columns: + testDataSpec = ( + dg.DataGenerator(spark, name="users_dataset", rows=1000) + .withColumn("user_name", expr="concat('user_', id)") + .withColumn("email_address", expr="concat(user_name, '@email.com')") + .withColumn("phone_number", template="555-DDD-DDDD") + ) + + # Define an output configuration: + outputDataset = OutputDataset( + "main.demo.table", + trigger={"processingTime": "10 seconds"} + ) + + # Generate and write the output data: + testDataSpec.saveAsDataset(dataset=outputDataset) + +Options for writing data +------------------------ + +Specify the ``output_mode`` and ``options`` to control how data is written to output tables or files. +Data will be written in append mode by default. + +.. code-block:: python + + import dbldatagen as dg + from dbldatagen.config import OutputDataset + + # Create a sample data generator with a few columns: + testDataSpec = ( + dg.DataGenerator(spark, name="users_dataset", rows=1000) + .withColumn("user_name", expr="concat('user_', id)") + .withColumn("email_address", expr="concat(user_name, '@email.com')") + .withColumn("phone_number", template="555-DDD-DDDD") + ) + + # Define an output configuration: + outputDataset = OutputDataset( + "/Volumes/main/demo/users_files/csv", + options={"mergeSchema": "true"}, + output_mode="overwrite" + ) + + # Generate and write the output data: + testDataSpec.saveAsDataset(dataset=outputDataset) + +Writing generated data to files +------------------------------- + +To write generated data to files (e.g. JSON or CSV), specify a ``format`` when creating your ``OutputConfig``. +File data can be written to a relative path using Databricks Volumes, an absolute path in cloud storage, or a path +in Databricks File System (DBFS). + +.. 
code-block:: python + + import dbldatagen as dg + from dbldatagen.config import OutputDataset + + # Create a sample data generator with a few columns: + testDataSpec = ( + dg.DataGenerator(spark, name="users_dataset", rows=1000) + .withColumn("user_name", expr="concat('user_', id)") + .withColumn("email_address", expr="concat(user_name, '@email.com')") + .withColumn("phone_number", template="555-DDD-DDDD") + ) + + # Define an output configuration: + outputDataset = OutputDataset( + "/Volumes/main/demo/users_files/csv", + format="csv", + options={"header": "true"} + ) + + # Generate and write the output data: + testDataSpec.saveAsDataset(dataset=outputDataset) diff --git a/tests/test_output.py b/tests/test_output.py new file mode 100644 index 00000000..3d41c50b --- /dev/null +++ b/tests/test_output.py @@ -0,0 +1,115 @@ +import os +import shutil +import time +import uuid +import pytest + +from pyspark.sql.types import IntegerType, StringType, FloatType + +import dbldatagen as dg + + +spark = dg.SparkSingleton.getLocalInstance("output tests") + + +class TestOutput: + @pytest.fixture + def get_output_directories(self): + base_dir = f"/tmp/testdatagenerator/{uuid.uuid4()}" + print("test dir created") + data_dir = os.path.join(base_dir, "data") + checkpoint_dir = os.path.join(base_dir, "checkpoint") + os.makedirs(data_dir) + os.makedirs(checkpoint_dir) + + print("\n\n*** Test directories", base_dir, data_dir, checkpoint_dir) + yield base_dir, data_dir, checkpoint_dir + + shutil.rmtree(base_dir, ignore_errors=True) + print(f"\n\n*** test dir [{base_dir}] deleted") + + @pytest.mark.parametrize("trigger", [{"availableNow": True}, {"once": True}, {"invalid": "yes"}]) + def test_initialize_output_dataset_invalid_trigger(self, trigger): + with pytest.raises(ValueError, match=f"Attribute 'trigger' must be a dictionary of the form"): + _ = dg.OutputDataset(location="/location", trigger=trigger) + + @pytest.mark.parametrize("seed_column_name, table_format", [("id", "parquet"), ("_id", "json"), ("id", "csv")]) + def test_build_output_data_batch(self, get_output_directories, seed_column_name, table_format): + base_dir, data_dir, checkpoint_dir = get_output_directories + table_dir = f"{data_dir}/{uuid.uuid4()}" + + gen = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=100, + partitions=4, + seedMethod='hash_fieldname', + seedColumnName=seed_column_name + ) + + gen = ( + gen + .withIdOutput() + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) + + output_dataset = dg.OutputDataset( + location=table_dir, + output_mode="append", + format=table_format, + options={"mergeSchema": "true"}, + ) + + gen.saveAsDataset(output_dataset) + persisted_df = spark.read.format(table_format).load(table_dir) + assert persisted_df.count() > 0 + + @pytest.mark.parametrize("seed_column_name, table_format", [("id", "parquet"), ("_id", "json"), ("id", "csv")]) + def test_build_output_data_streaming(self, get_output_directories, seed_column_name, table_format): + base_dir, data_dir, checkpoint_dir = get_output_directories + table_dir = f"{data_dir}/{uuid.uuid4()}" + + gen = dg.DataGenerator( + sparkSession=spark, + name="test_data_set1", + rows=100, + partitions=4, + seedMethod='hash_fieldname', + 
seedColumnName=seed_column_name + ) + + gen = ( + gen + .withIdOutput() + .withColumn("code1", IntegerType(), minValue=100, maxValue=200) + .withColumn("code2", IntegerType(), minValue=0, maxValue=10) + .withColumn("code3", StringType(), values=['a', 'b', 'c']) + .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True) + .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1]) + ) + + output_dataset = dg.OutputDataset( + location=table_dir, + output_mode="append", + format=table_format, + options={"mergeSchema": "true", "checkpointLocation": f"{data_dir}/{checkpoint_dir}"}, + trigger={"processingTime": "1 SECOND"} + ) + + query = gen.saveAsDataset(output_dataset, with_streaming=True) + + start_time = time.time() + elapsed_time = 0 + time_limit = 10.0 + + while elapsed_time < time_limit: + time.sleep(1) + elapsed_time = time.time() - start_time + + query.stop() + persisted_df = spark.read.format(table_format).load(table_dir) + assert persisted_df.count() > 0 diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 9ed97b9f..f65361cb 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -11,7 +11,7 @@ spark = dg.SparkSingleton.getLocalInstance("streaming tests") -class TestStreaming(): +class TestStreaming: row_count = 100000 column_count = 10 time_to_run = 10 From de1bb7539466a4353ef192d5625ecdef4b4fe2ba Mon Sep 17 00:00:00 2001 From: Greg Hansen <163584195+ghanse@users.noreply.github.com> Date: Tue, 7 Oct 2025 11:46:34 -0400 Subject: [PATCH 12/20] Format modules (#367) --- dbldatagen/data_generator.py | 2 +- dbldatagen/datasets/dataset_provider.py | 5 +- dbldatagen/datasets_object.py | 367 +++++++++++++++--------- dbldatagen/function_builder.py | 68 +++-- dbldatagen/html_utils.py | 63 ++-- pyproject.toml | 13 +- 6 files changed, 306 insertions(+), 212 deletions(-) diff --git a/dbldatagen/data_generator.py b/dbldatagen/data_generator.py index a08c5537..14fa92e0 100644 --- a/dbldatagen/data_generator.py +++ b/dbldatagen/data_generator.py @@ -1913,7 +1913,7 @@ def scriptMerge( result = "\n".join(results) if asHtml: - result = HtmlUtils.formatCodeAsHtml(results) + result = HtmlUtils.formatCodeAsHtml(result) return result diff --git a/dbldatagen/datasets/dataset_provider.py b/dbldatagen/datasets/dataset_provider.py index a6882c6c..d67c440d 100644 --- a/dbldatagen/datasets/dataset_provider.py +++ b/dbldatagen/datasets/dataset_provider.py @@ -20,6 +20,7 @@ This file defines the DatasetProvider class """ + class DatasetProvider(ABC): """ The DatasetProvider class acts as a base class for all dataset providers @@ -206,7 +207,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N raise NotImplementedError("Base data provider does not provide any table generation specifications!") @abstractmethod - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, + def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str | None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: """ Gets associated datasets that are used in conjunction with the provider datasets. 
@@ -240,7 +241,7 @@ def allowed_options(options: list[str]|None =None) -> Callable[[Callable], Calla def decorator(func: Callable) -> Callable: @functools.wraps(func) - def wrapper(*args, **kwargs) -> Callable: # noqa: ANN002 + def wrapper(*args, **kwargs) -> Callable: bad_options = [keyword_arg for keyword_arg in kwargs if keyword_arg not in DEFAULT_OPTIONS and keyword_arg not in options] diff --git a/dbldatagen/datasets_object.py b/dbldatagen/datasets_object.py index 602e68ff..067a37a3 100644 --- a/dbldatagen/datasets_object.py +++ b/dbldatagen/datasets_object.py @@ -19,16 +19,19 @@ import re +from pyspark.sql import SparkSession + +from dbldatagen.data_generator import DataGenerator from dbldatagen.datasets.dataset_provider import DatasetProvider from dbldatagen.spark_singleton import SparkSingleton from dbldatagen.utils import strip_margins -from dbldatagen.data_generator import DataGenerator class Datasets: - """This class is used to generate standard data sets based on a plugin provider model. + """ + This class is used to generate standard data sets based on a plugin provider model. - It allows for quick generation of data for common scenarios. + It allows for quick generation of data for common scenarios. :param sparkSession: Spark session instance to use when performing spark operations :param name: Dataset name to use @@ -43,20 +46,22 @@ class Datasets: If a dataset provider supports multiple tables, the name of the table to retrieve is passed to the `get` method, along with any parameters that are required to generate the data. - """ @classmethod - def getProviderDefinitions(cls, name=None, pattern=None, supportsStreaming=False): - """Get provider definitions for one or more datasets - - :param name: name of dataset to get provider for, if None, returns all providers - :param pattern: pattern to match dataset name, if None, returns all providers optionally matching name - :param supportsStreaming: If true, filters out dataset providers that don't support streaming - :return: list of provider definitions matching name and pattern - - Each entry will be of the form DatasetProvider.DatasetProviderDefinition + def getProviderDefinitions( + cls, + name: str | None = None, + pattern: str | None = None, + supportsStreaming: bool = False + ) -> list[DatasetProvider.DatasetDefinition]: + """ + Gets provider definitions for one or more datasets. + :param name: Name of dataset to get provider for, if ``None``, returns all providers. + :param pattern: Pattern to match dataset name, if ``None``, returns all providers optionally matching name. + :param supportsStreaming: If ``true``, filters out dataset providers that don't support streaming. + :returns: List of provider definitions matching input name and pattern. """ if pattern is not None and name is not None: summary_list = [provider_definition @@ -83,12 +88,12 @@ def getProviderDefinitions(cls, name=None, pattern=None, supportsStreaming=False return summary_list @classmethod - def list(cls, pattern=None, supportsStreaming=False): - """This method lists the registered datasets - It filters the list by a regular expression pattern if provided + def list(cls, pattern: str | None = None, supportsStreaming: bool = False) -> None: + """ + Lists registered datasets. Optionally filters the list by a regular expression pattern if provided. - :param pattern: Pattern to match dataset names. 
If None, all datasets are listed - :param supportsStreaming: if True, only return providerDefinitions that supportStreaming + :param pattern: Pattern to match dataset names. If ``None``, all datasets are listed. + :param supportsStreaming: If ``true``, filters out dataset providers that don't support streaming. """ summary_list = sorted([(providerDefinition.name, providerDefinition.summary) for providerDefinition in cls.getProviderDefinitions(name=None, pattern=pattern, @@ -100,11 +105,11 @@ def list(cls, pattern=None, supportsStreaming=False): print(f" Provider: `{entry[0]}` - Summary description: {entry[1]}") @classmethod - def describe(cls, name): - """This method lists the registered datasets - It filters the list by a regular expression pattern if provided + def describe(cls, name: str) -> None: + """ + Prints a description for the input dataset. - :param name: name of dataset to describe + :param name: Name of dataset to describe """ providers = cls.getProviderDefinitions(name=name) @@ -112,28 +117,30 @@ def describe(cls, name): providerDef = providers[0] - summaryAttributes = f""" - | Dataset Name: {providerDef.name} - | Summary: {providerDef.summary} - | Supports Streaming: {providerDef.supportsStreaming} - | Provides Table Generators: {providerDef.tables} - | Primary Table: {providerDef.primaryTable} - | Associated Datasets: {providerDef.associatedDatasets} - |""" + summaryAttributes = f""" + | Dataset Name: {providerDef.name} + | Summary: {providerDef.summary} + | Supports Streaming: {providerDef.supportsStreaming} + | Provides Table Generators: {providerDef.tables} + | Primary Table: {providerDef.primaryTable} + | Associated Datasets: {providerDef.associatedDatasets} + |""" print(f"The dataset '{providerDef.name}' is described as follows:") - print(strip_margins(summaryAttributes, '|')) + print(strip_margins(summaryAttributes, "|")) print("\n".join([x.strip() for x in providers[0].description.split("\n")])) print("") print("Detailed description:") print("") print(providerDef.description) - def __init__(self, sparkSession, name=None, streaming=False): - """ Constructor: - :param sparkSession: Spark session to use - :param name: name of dataset to search for - :param streaming: if True, validdates that dataset supports streaming data + def __init__(self, sparkSession: SparkSession, name: str | None = None, streaming: bool = False) -> None: + """ + Creates a ``Datasets`` object. + + :param sparkSession: ``SparkSession`` to use + :param name: Name of the dataset to search for + :param streaming: If True, validdates that dataset supports streaming """ if not sparkSession: sparkSession = SparkSingleton.getLocalInstance() @@ -147,11 +154,16 @@ def __init__(self, sparkSession, name=None, streaming=False): self._datasetsVersion = DatasetProvider.getRegisteredDatasetsVersion() self._navigator = None - def _getNavigator(self): + def _getNavigator(self) -> NavigatorNode: + """ + Gets a navigator for the current dataset. 
+ + :returns: Navigator for the current dataset + """ latestVersion = DatasetProvider.getRegisteredDatasetsVersion() if self._datasetsVersion != latestVersion or not self._navigator: # create a navigator object to support x.y.z notation - root = self.NavigatorNode(self) + root = NavigatorNode(self) providersMap = DatasetProvider.getRegisteredDatasets() @@ -167,7 +179,18 @@ def _getNavigator(self): return self._navigator - def _getProviderInstanceAndMetadata(self, providerName, supportsStreaming): + def _getProviderInstanceAndMetadata( + self, + providerName: str, + supportsStreaming: bool + ) -> tuple[DatasetProvider, DatasetProvider.DatasetDefinition]: + """ + Gets a dataset provider and definition. + + :param providerName: Name of the dataset provider + :param supportsStreaming: If True, validdates that dataset supports streaming + :returns: Requested `DatasetProvider` and the associated `DatasetDefinition` + """ assert providerName is not None and len(providerName), "Dataset provider name must be supplied" providers = self.getProviderDefinitions(name=providerName, supportsStreaming=supportsStreaming) @@ -185,7 +208,19 @@ def _getProviderInstanceAndMetadata(self, providerName, supportsStreaming): return providerInstance, providerDefinition - def _get(self, *, providerName, tableName, rows=-1, partitions=-1, **kwargs): + def _get( + self, *, providerName: str, tableName: str | None, rows: int = -1, partitions: int = -1, **kwargs + ) -> DataGenerator: + """ + Gets a table generator from the dataset provider. + + :param providerName: Name of the dataset provider + :param tableName: Optional name of the table to get (if ``None``, the provider's primary table is returned) + :param rows: Optional number of rows to generate (if -1, provider should compute defaults) + :param partitions: Optional number of partitions to use (if -1, the number of partitions is computed automatically) + :param kwargs: Additional keyword arguments to pass to the provider + :returns: `DataGenerator` for the requested table + """ providerInstance, providerDefinition = \ self._getProviderInstanceAndMetadata(providerName, supportsStreaming=self._streamingRequired) @@ -197,20 +232,24 @@ def _get(self, *, providerName, tableName, rows=-1, partitions=-1, **kwargs): if tableName not in providerDefinition.tables: raise ValueError(f"Table `{tableName}` not a recognized table option") - tableDefn = providerInstance.getTableGenerator(self._sparkSession, tableName=tableName, rows=rows, - partitions=partitions, - **kwargs) - return tableDefn + return providerInstance.getTableGenerator( + self._sparkSession, + tableName=tableName, + rows=rows, + partitions=partitions, + **kwargs + ) - def get(self, table=None, rows=-1, partitions=-1, **kwargs): - """Get a table generator from the dataset provider + def get(self, table: str | None = None, rows: int = -1, partitions: int = -1, **kwargs) -> DataGenerator: + """ + Gets a table generator from the dataset provider. - These are DataGenerator instances that can be used to generate the data. - The dataset providers also optionally can provide supporting tables which are computed tables based on - parameters. These are retrieved using the `getAssociatedDataset` method + These are `DataGenerator` instances that can be used to generate the data. Dataset providers also optionally + can provide supporting tables which are computed tables based on parameters. 
These are retrieved using the + `getAssociatedDataset` method If the dataset supports multiple tables, the table may be specified in the `table` parameter. - If none is specified, the primary table is used. + If ``None`` is specified, a generator for the primary table is returned. If `rows` or `partitions` are not specified, default values are supplied by the provider. @@ -218,21 +257,31 @@ def get(self, table=None, rows=-1, partitions=-1, **kwargs): be optionally supplied. Additionally, for multi-table datasets, the table name must be one of the tables supported by the provider. - Default number of rows for multi-table datasets may differ - for example a 'customers' table may have a - 100,000 rows while a 'sales' table may have 1,000,000 rows. + THe default number of rows for each table in a multi-table dataset may vary - for example a 'customers' table + may have 100,000 rows while a 'sales' table may have 1,000,000 rows. - :param table: name of table to retrieve - :param rows: number of rows to generate. if -1, provider should compute defaults. - :param partitions: number of partitions to use.If -1, the number of partitions is computed automatically - table size and partitioning.If applied to a dataset with only a single table, this is ignored. - :param kwargs: additional keyword arguments to pass to the provider - :returns: table generator + :param table: Name of table to retrieve + :param rows: Optional number of rows to generate (if -1, provider should compute defaults) + :param partitions: Optional number of partitions to use (if -1, the number of partitions is computed automatically) + :param kwargs: Additional keyword arguments to pass to the provider + :returns: `DataGenerator` for the requested table """ - return self._get(providerName=self._name, tableName=table, rows=rows, partitions=partitions, - **kwargs) + return self._get(providerName=self._name, tableName=table, rows=rows, partitions=partitions, **kwargs) - def _getSupportingTable(self, *, providerName, tableName, rows=-1, partitions=-1, **kwargs): + def _getSupportingTable( + self, *, providerName: str, tableName: str, rows: int = -1, partitions: int = -1, **kwargs + ) -> DataGenerator: + """ + Gets a supporting table needed to build a multi-table dataset. + + :param providerName: Name of the dataset provider + :param tableName: Name of the table to get from the provider + :param rows: Optional number of rows to generate (if -1, provider should compute defaults) + :param partitions: Optional number of partitions to use (if -1, the number of partitions is computed automatically) + :param kwargs: Additional keyword arguments to pass to the provider + :returns: `DataGenerator` for the supporting table + """ providerInstance, providerDefinition = self._getProviderInstanceAndMetadata( providerName, supportsStreaming=self._streamingRequired ) @@ -247,7 +296,7 @@ def _getSupportingTable(self, *, providerName, tableName, rows=-1, partitions=-1 ) return dfSupportingTable - def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs) -> DataGenerator: + def getAssociatedDataset(self, *, table: str, rows: int = -1, partitions: int = -1, **kwargs) -> DataGenerator: """ Gets a table generator from the dataset provider. @@ -275,8 +324,9 @@ def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs) -> Da .. 
note :: This method may also be invoked via the aliased names - `getSupportingDataset` and `getCombinedDataset` """ - return self._getSupportingTable(providerName=self._name, tableName=table, rows=rows, partitions=partitions, - **kwargs) + return self._getSupportingTable( + providerName=self._name, tableName=table, rows=rows, partitions=partitions, **kwargs + ) # aliases @@ -288,7 +338,7 @@ def getAssociatedDataset(self, *, table, rows=-1, partitions=-1, **kwargs) -> Da getSummaryDataset = getAssociatedDataset getEnrichedDataset = getAssociatedDataset - def __getattr__(self, path): + def __getattr__(self, path: str) -> NavigatorNode: assert path is not None, "path should be non-empty" navigator = self._getNavigator() @@ -307,98 +357,139 @@ def __getattr__(self, path): return navigator - class NavigatorNode: - """Dataset Navigator class for navigating datasets - This class is used to navigate datasets and their tables via dotted notation. +class NavigatorNode: + """ + This class is used to navigate datasets and their tables via dotted notation (i.e. + ``X.dataset_grouping.dataset.table`` where X is an instance of the dataset navigator.) - Ie X.dataset_grouping.dataset.table where X is an intance of the dataset navigator. + The navigator is initialized with a set of paths and objects (usually providers) that are registered with the + `DatasetProvider` class. - The navigator is initialized with a set of paths and objects (usually providers) that are registered with the - DatasetProvider class. + When accessed via dotted notation, the navigator will use the `pathSegment` to locate the provider and create + it. - When accessed via dotted notation, the navigator will use the pathSegment to locate the provider and create it. + Any remaining `pathSegment` traversed will be used to locate the table within the provider. - Any remaining pathSegment traversed will be used to locate the table within the provider. + This provides a syntactic layering over the creation of the provider instance and table generation. - Overall, this just provides a syntactic layering over the creation of the provider instance - and table generation. + :param datasets: `Datasets` object + :param providerName: Dataset provider name for the node + :param tableName: Table name for the node + :param location: Location for the node - used in error reporting + """ + def __init__( + self, + datasets: Datasets, + providerName: str | None = None, + tableName: str | None = None, + location: list[str] | None = None + ) -> None: + """ Initialization for node + + :param datasets: instance of datasets object + :param providerName: provider name for node + :param tableName: table name for node + :param location: location for node - used in error reporting """ + self._datasets = datasets + self._children = None + self._providerName = providerName + self._tableName = tableName + self._location = location # expected to be a list of the attributes used to navigate to the node + + def __repr__(self) -> str: + return f"Node: (datasets: {self._datasets}, provider: {self._providerName}, loc: {self._location} )" + + def _addEntry( + self, + datasets: Datasets, + steps: list[str] | None, + providerName: str | None, + tableName: str | None + ) -> NavigatorNode: + """ + Adds an entry to the dataset navigator. 
- def __init__(self, datasets, providerName=None, tableName=None, location=None): - """ Initialization for node - - :param datasets: instance of datasets object - :param providerName: provider name for node - :param tableName: table name for node - :param location: location for node - used in error reporting - """ - self._datasets = datasets - self._children = None - self._providerName = providerName + :param datasets: `Datasets` object + :steps: List of steps to add + :param providerName: Dataset provider name for the node + :param tableName: Table name for the node + :returns: `NavigatorNode` with the steps added + """ + results = self + if steps is None or len(steps) == 0: self._tableName = tableName - self._location = location # expected to be a list of the attributes used to navigate to the node - - def __repr__(self): - return f"Node: (datasets: {self._datasets}, provider: {self._providerName}, loc: {self._location} )" - - def _addEntry(self, datasets, steps, providerName, tableName): - - results = self - if steps is None or len(steps) == 0: - self._tableName = tableName - self._providerName = providerName - else: - new_location = self._location + [steps[0]] if self._location is not None else [steps[0]] - if self._children is None: # no children exist - newNode = datasets.NavigatorNode(datasets, location=new_location) - self._children = {steps[0]: newNode._addEntry(datasets, steps[1:], providerName, tableName)} - elif steps[0] in self._children: # step is in the child dictionary - self._children[steps[0]]._addEntry(datasets, steps[1:], providerName, tableName) - else: # step is not in the child dictionary - newNode = datasets.NavigatorNode(datasets, location=new_location) - self._children[steps[0]] = newNode._addEntry(datasets, steps[1:], providerName, tableName) + self._providerName = providerName + else: + new_location = [*self._location, steps[0]] if self._location is not None else [steps[0]] + if self._children is None: # no children exist + newNode = NavigatorNode(datasets, location=new_location) + self._children = {steps[0]: newNode._addEntry(datasets, steps[1:], providerName, tableName)} + elif steps[0] in self._children: # step is in the child dictionary + self._children[steps[0]]._addEntry(datasets, steps[1:], providerName, tableName) + else: # step is not in the child dictionary + newNode = NavigatorNode(datasets, location=new_location) + self._children[steps[0]] = newNode._addEntry(datasets, steps[1:], providerName, tableName) + + return results + + def addEntry(self, datasets: Datasets, providerName: str, tableName: str | None = None) -> None: + """ + Adds an entry to the dataset navigator. 
- return results + :param datasets: `Datasets` object + :param providerName: Dataset provider name for the node + :param tableName: Table name for the node + """ + provider_steps = [x.strip() for x in providerName.split("/") if x is not None and len(x) > 0] - def addEntry(self, datasets, providerName, tableName): - provider_steps = [x.strip() for x in providerName.split("/") if x is not None and len(x) > 0] + self._addEntry(datasets, provider_steps, providerName, tableName) + # add an entry allowing navigation of the form `Datasets("basic").user()` with addition of table name + if tableName is not None: + provider_steps.append(tableName) self._addEntry(datasets, provider_steps, providerName, tableName) - # add an entry allowing navigation of the form `Datasets("basic").user()` with addition of table name - if tableName is not None: - provider_steps.append(tableName) - self._addEntry(datasets, provider_steps, providerName, tableName) + def find(self, attributePath: str) -> NavigatorNode: + """ + Gets a `NavigatorNode` with the provided path. + + :param attributePath: Attribute path to search + :returns: `NavigatorNode` with the provided path + """ + provider_steps = [x.strip() for x in attributePath.split("/") if x is not None and len(x) > 0] - def find(self, attributePath): - provider_steps = [x.strip() for x in attributePath.split("/") if x is not None and len(x) > 0] + node = self + for step in provider_steps: + if node._children is not None and step in node._children: + node = node._children[step] + else: + node = None + return node - node = self - for step in provider_steps: - if node._children is not None and step in node._children: - node = node._children[step] - else: - node = None - return node + def isFinal(self) -> bool: + """ + Checks if the navigator has a named dataset provider. - def isFinal(self): - return self._providerName is not None + :returns: `True` if the navigator has a named dataset provider + """ + return self._providerName is not None - def __getattr__(self, path): - node = self.find(path) + def __getattr__(self, path: str) -> NavigatorNode: + node = self.find(path) - if node is None: - location_path = ".".join(self._location) + "." + path - raise ValueError(f"Provider / table not found {path} in sequence `{location_path}`") - return node + if node is None: + location_path = ".".join(self._location) + "." 
+ path + raise ValueError(f"Provider / table not found {path} in sequence `{location_path}`") + return node - def __call__(self, *args, **kwargs): - if not self.isFinal(): - raise ValueError(f"Cant resolve provider / table name for sequence {self._location}") + def __call__(self, *args, **kwargs) -> DataGenerator: + if not self.isFinal(): + raise ValueError(f"Cant resolve provider / table name for sequence {self._location}") - if self._tableName is not None: - return self._datasets._get(*args, providerName=self._providerName, tableName=self._tableName, **kwargs) - else: - return self._datasets._get(*args, providerName=self._providerName, **kwargs) + if self._tableName is not None: + return self._datasets._get(*args, providerName=self._providerName, tableName=self._tableName, **kwargs) + else: + return self._datasets._get(*args, providerName=self._providerName, **kwargs) diff --git a/dbldatagen/function_builder.py b/dbldatagen/function_builder.py index f4d68889..0203099d 100644 --- a/dbldatagen/function_builder.py +++ b/dbldatagen/function_builder.py @@ -1,30 +1,46 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +""" +This file defines the `ColumnGeneratorBuilder` class and utility functions +""" + import itertools +from typing import Any -from pyspark.sql.types import StringType, DateType, TimestampType +from pyspark.sql.types import DataType, DateType, StringType, TimestampType class ColumnGeneratorBuilder: - """ Helper class to build functional column generators of specific forms""" + """ + Helper class to build functional column generators of specific forms + """ @classmethod - def _mkList(cls, x): + def _mkList(cls, x: object) -> list: + """ + Makes a list of the supplied object instance if it is not already a list. + + :param x: Input object to process + :returns: List containing the supplied object if it is not already a list; otherwise returns the object """ - Makes a list of the supplied object instance if it is not already a list - :param x: object to process - :returns: Returns list of supplied object if it is not already a list, otherwise simply returns the object""" return [x] if type(x) is not list else x @classmethod - def _lastElement(cls, x): - """ Gets the last element, if the object is a list otherwise returns the object itself""" - return x[-1] if type(x) is list else x + def _lastElement(cls, x: object) -> object: + """ + Gets the last element from the supplied object if it is a list. + + :param x: Input object + :returns: Last element of the input object if it is a list; otherwise returns the object + """ + return x[-1] if isinstance(x, list) else x @classmethod - def _mkCdfProbabilities(cls, weights): - """ make cumulative distribution function probabilities for each value in values list + def _mkCdfProbabilities(cls, weights: list[float]) -> list[float]: + """ + Makes cumulative distribution function probabilities for each value in values list. a cumulative distribution function for discrete values can uses a table of cumulative probabilities to evaluate different expressions @@ -46,6 +62,9 @@ def _mkCdfProbabilities(cls, weights): while datasets of size 10,000 x `number of values` gives a repeated distribution within 5% of expected distribution. + :param weights: List of weights to compute CDF probabilities for + :returns: List of CDF probabilities + Example code to be generated (pseudo code):: # given values value1 .. 
valueN, prob1 to probN @@ -61,13 +80,12 @@ def _mkCdfProbabilities(cls, weights): """ total_weights = sum(weights) - return list(map(lambda x: x / total_weights, itertools.accumulate(weights))) + return [x / total_weights for x in itertools.accumulate(weights)] @classmethod - def mkExprChoicesFn(cls, values, weights, seed_column, datatype): - """ Create SQL expression to compute the weighted values expression - - build an expression of the form:: + def mkExprChoicesFn(cls, values: list[Any], weights: list[float], seed_column: str, datatype: DataType) -> str: + """ + Creates a SQL expression to compute a weighted values expression. Builds an expression of the form:: case when rnd_column <= weight1 then value1 @@ -77,22 +95,22 @@ def mkExprChoicesFn(cls, values, weights, seed_column, datatype): else valueN end - based on computed probability distribution for values. - - In Python 3.6 onwards, we could use the choices function but this python version is not - guaranteed on all Databricks distributions + The output expression is based on the computed probability distribution for the specified values. - :param values: list of values - :param weights: list of weights - :param seed_column: base column for expression - :param datatype: data type of function return value + In Python 3.6 onwards, we could use the choices function but this python version is not guaranteed on all + Databricks distributions. + :param values: List of values + :param weights: List of weights + :param seed_column: Base column name for expression + :param datatype: Spark `DataType` of the output expression + :returns: SQL expression representing the weighted values """ cdf_probs = cls._mkCdfProbabilities(weights) output = [" CASE "] - conditions = zip(values, cdf_probs) + conditions = zip(values, cdf_probs, strict=False) for v, cdf in conditions: # TODO(alex): single quotes needs to be escaped diff --git a/dbldatagen/html_utils.py b/dbldatagen/html_utils.py index 326d92c6..de8e2526 100644 --- a/dbldatagen/html_utils.py +++ b/dbldatagen/html_utils.py @@ -6,40 +6,40 @@ This file defines the `HtmlUtils` classes and utility functions """ -from .utils import system_time_millis +from dbldatagen.utils import system_time_millis class HtmlUtils: - """ Utility class for formatting code as HTML and other notebook related formatting - + """ + Utility class for formatting code as HTML and other notebook-related formatting. """ - def __init__(self): + def __init__(self) -> None: pass - @classmethod - def formatCodeAsHtml(cls, codeText): - """ Formats supplied code as Html suitable for use with notebook ``displayHTML`` - - :param codeText: Code to be wrapped in html section - :return: Html string + @staticmethod + def formatCodeAsHtml(codeText: str) -> str: + """ + Formats the input code as HTML suitable for use with a notebook's ``displayHTML`` command. - This will wrap the code with a html section using html ``pre`` and ``code`` tags. + This method wraps the input code with an html section using ``pre`` and ``code`` tags. It adds a *Copy Text to + Clipboard* button which allows users to easily copy the code to the clipboard. - It adds a copy text to clipboard button to enable users to easily copy the code to the clipboard. + Code is not reformatted. Supplied code should be preformatted into lines. - It does not reformat code so supplied code should be preformatted into lines. + :param codeText: Input code as a string + :return: Formatted code as an HTML string .. 
note:: As the notebook environment uses IFrames in rendering html within ``displayHtml``, it cannot use the newer ``navigator`` based functionality as this is blocked for cross domain IFrames by default. """ - ts = system_time_millis() + current_ts = system_time_millis() - formattedCode = f""" + return f"""

             [HTML template markup elided from this hunk: a styled section headed "Generated Code" wrapping a <pre><code> block that contains {codeText}]
@@ -48,7 +48,7 @@ def formatCodeAsHtml(cls, codeText): function dbldatagen_copy_code_to_clipboard() {{ try {{ var r = document.createRange(); - r.selectNode(document.getElementById("generated_code_{ts}")); + r.selectNode(document.getElementById("generated_code_{current_ts}")); window.getSelection().removeAllRanges(); window.getSelection().addRange(r); document.execCommand('copy'); @@ -61,23 +61,20 @@ def formatCodeAsHtml(cls, codeText): """ - return formattedCode - - @classmethod - def formatTextAsHtml(cls, textContent, title="Output"): - """ Formats supplied text as Html suitable for use with notebook ``displayHTML`` - - :param textContent: Text to be wrapped in html section - :param title: Title text to be used - :return: Html string - - This will wrap the text content with with Html formatting + @staticmethod + def formatTextAsHtml(textContent: str, title: str = "Output") -> str: + """ + Formats the input text as HTML suitable for use with a notebook's ``displayHTML`` command. This wraps the text + content with HTML formatting blocks and adds a section title. + :param textContent: Input text to be wrapped in an HTML section + :param title: Section title (default `"Output"`) + :return: Text section as an HTML string """ - ts = system_time_millis() - formattedContent = f""" + current_ts = system_time_millis() + return f"""

{title}

-

 
+            

               {textContent}
             


@@ -86,7 +83,7 @@ def formatTextAsHtml(cls, textContent, title="Output"): function dbldatagen_copy_to_clipboard() {{ try {{ var r = document.createRange(); - r.selectNode(document.getElementById("generated_content_{ts}")); + r.selectNode(document.getElementById("generated_content_{current_ts}")); window.getSelection().removeAllRanges(); window.getSelection().addRange(r); document.execCommand('copy'); @@ -98,5 +95,3 @@ def formatTextAsHtml(cls, textContent, title="Output"): }} """ - - return formattedContent diff --git a/pyproject.toml b/pyproject.toml index 99be0820..edffc15e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,10 +159,7 @@ exclude = [ "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", - "dbldatagen/datasets_object.py", "dbldatagen/daterange.py", - "dbldatagen/function_builder.py", - "dbldatagen/html_utils.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/text_generator_plugins.py", @@ -198,6 +195,7 @@ ignore = [ "SIM102", # Use a single if-statement "SIM108", # Use ternary operator "UP007", # Use X | Y for type annotations (keep Union for compatibility) + "ANN002", # Missing type annotation for *args "ANN003", # Missing type annotation for **kwargs ] @@ -242,10 +240,7 @@ ignore = [ "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", - "dbldatagen/datasets_object.py", "dbldatagen/daterange.py", - "dbldatagen/function_builder.py", - "dbldatagen/html_utils.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", @@ -281,10 +276,7 @@ ignore-paths = [ "dbldatagen/data_generator.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", - "dbldatagen/datasets_object.py", "dbldatagen/daterange.py", - "dbldatagen/function_builder.py", - "dbldatagen/html_utils.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", @@ -414,10 +406,7 @@ exclude = [ "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", - "dbldatagen/datasets_object.py", "dbldatagen/daterange.py", - "dbldatagen/function_builder.py", - "dbldatagen/html_utils.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", From 00da1cd43943c8fc4603af45254c649069719f0b Mon Sep 17 00:00:00 2001 From: brettaulbaugh-db Date: Tue, 14 Oct 2025 18:14:04 -0500 Subject: [PATCH 13/20] Example notebook for Oil and Gas industry (#363) * Add files via upload Adding an example notebook on how to leverage DBLDATAGEN in the creation of oil and gas datasets/. 
Currently placing this in the examples/notebooks file * Update and rename [DBLDATAGEN]Oil&GasWellHeaderDailyProductionTypeCurve.py to oil_gas_data_generation.py implemented changes Greg suggested aorund function formatting * Update oil_gas_data_generation.py added documentaiton link and install clarification --------- Co-authored-by: Greg Hansen <163584195+ghanse@users.noreply.github.com> --- examples/notebooks/oil_gas_data_generation.py | 518 ++++++++++++++++++ 1 file changed, 518 insertions(+) create mode 100644 examples/notebooks/oil_gas_data_generation.py diff --git a/examples/notebooks/oil_gas_data_generation.py b/examples/notebooks/oil_gas_data_generation.py new file mode 100644 index 00000000..511c6b07 --- /dev/null +++ b/examples/notebooks/oil_gas_data_generation.py @@ -0,0 +1,518 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Synthetic Oil & Gas Data Generation with DBLDATAGEN +# MAGIC +# MAGIC Welcome to this tutorial notebook! Here, you'll learn how to use the **[DBLDATAGEN](https://databrickslabs.github.io/dbldatagen/public_docs/APIDOCS.html)** library to generate realistic synthetic datasets for oil & gas analytics. The Daily Production dataset is a fundamental analytical product for all upstream operators and this notebook walks through the creation of this dataset using DBLDATAGEN. +# MAGIC +# MAGIC --- +# MAGIC +# MAGIC ## What You'll Learn +# MAGIC +# MAGIC - **Define Data Generators for well header, daily production, and type curve dataset based on ARPS decline curve parameters** for multiple formations +# MAGIC - **Generate well header, daily production, and type curve forecast data** +# MAGIC - **Visualize and analyze synthetic production data** +# MAGIC +# MAGIC --- +# MAGIC +# MAGIC This notebook is designed for **petroleum engineers**, **data scientists**, and **analytics professionals** who want to quickly create and experiment with realistic E&P datasets in Databricks. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Section: Install DBLDATAGEN +# MAGIC +# MAGIC This section ensures the **DBLDATAGEN** library is installed in your Databricks environment. +# MAGIC +# MAGIC > **DBLDATAGEN** is a powerful tool for generating large-scale synthetic datasets, ideal for testing, prototyping, and analytics development in oil & gas. dbldatagen can be installed using pip install commands, as a cluster-scoped library, or as a serverless environment-scoped library. +# MAGIC +# MAGIC --- + +# COMMAND ---------- + +# DBTITLE 1,Install DBLDATAGEN +# MAGIC %pip install dbldatagen + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Section: Import Libraries +# MAGIC +# MAGIC This section imports all necessary Python libraries: +# MAGIC +# MAGIC - **PySpark**: Distributed data processing +# MAGIC - **DBLDATAGEN**: Synthetic data generation +# MAGIC - **Pandas**: Data manipulation +# MAGIC - **Matplotlib** & **Seaborn**: Data visualization +# MAGIC +# MAGIC These libraries are essential for data generation, manipulation, and visualization throughout the notebook. 
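Before the generator functions that follow, a minimal sketch of the dbldatagen pattern every cell in this notebook builds on may help orient readers: construct a `DataGenerator`, attach column specifications with `withColumn(...)`, then call `build()` to materialize a Spark DataFrame. The sketch below is illustrative only and is not part of the patch; the column names are examples.

import dbldatagen as dg

# Minimal, illustrative generator: two columns, 1,000 rows (column names are examples only)
example_spec = (
    dg.DataGenerator(sparkSession=spark, name="example_spec", rows=1000, partitions=4, randomSeed=42)
    .withColumn("WELL_ID", "bigint", minValue=1, maxValue=1000, random=True)
    .withColumn("STATUS", "string", values=["Producing", "Shut-in"], weights=[90, 10], random=True)
)
example_df = example_spec.build()  # build() returns a Spark DataFrame
example_df.show(5)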
+# MAGIC +# MAGIC --- + +# COMMAND ---------- + +# DBTITLE 1,Import required packages +from pyspark.sql.functions import col +from datetime import date, timedelta +import dbldatagen as dg +import pandas as pd +import matplotlib.pyplot as plt +import random +import seaborn as sns + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Section: Data Generation Functions +# MAGIC +# MAGIC This section defines **reusable functions** for generating synthetic well header data, daily production profiles, and type curve forecasts. +# MAGIC +# MAGIC - Each function is **modular** and **parameterized**, allowing you to easily adjust: +# MAGIC - Number of wells +# MAGIC - Decline curve parameters (ARPS methodology) +# MAGIC - Other simulation settings +# MAGIC +# MAGIC > ⚙️ **Tip:** Adjust parameters to simulate different reservoir scenarios and production behaviors. +# MAGIC +# MAGIC --- + +# COMMAND ---------- + +# DBTITLE 1,Well Header Data Generation Function +def generate_well_header_data( + spark_session: "pyspark.sql.SparkSession", + formation: str, + q_i: float, + d: float, + b_factor: float, + num_assets: int = 1000, + partitions: int = 4, + randomSeed: int = None +) -> dg.DataGenerator: + """ + Creates a DataFrame with synthetic well header data. + + Args: + spark_session: Current SparkSession + formation: Producing formation name + q_i: Initial production rate (BOPD) + d: Initial decline rate + b_factor: ARPS b-factor + num_assets: Number of wells to generate + partitions: Number of partitions for parallelism (optional) + randomSeed: Random seed for reproducibility (optional) + + Returns: + A Spark DataFrame with synthetic well header data + """ + row_count = num_assets + partitions_requested = partitions + if randomSeed is None: + randomSeed = int(random.uniform(20, 1000)) # Random seed for reproducibility + + data_spec = ( + dg.DataGenerator( + sparkSession=spark_session, + name=formation, + rows=row_count, + partitions=partitions_requested, + randomSeed=randomSeed, + ) + .withColumn( + "API_NUMBER", + "bigInt", + minValue=42000000000000, + maxValue=42999999999999, + random=True ) + .withColumn( + "FIELD_NAME", + "string", + values=["Field_1", "Field_2", "Field_3", "Field_4", "Field_5"], + random=True, + ) + .withColumn( + "LATITUDE", + "float", + minValue=31.00, + maxValue=32.50, + step=1e-6, + random=True, + ) + .withColumn( + "LONGITUDE", + "float", + minValue=-104.00, + maxValue=-101.00, + step=1e-6, + random=True, + ) + .withColumn( + "COUNTY", + "string", + values=["Reeves", "Midland", "Ector", "Loving", "Ward"], + random=True, + ) + .withColumn("STATE", "string", values=["Texas"]) + .withColumn("COUNTRY", "string", values=["USA"]) + .withColumn( + "WELL_TYPE", + "string", + values=["Oil"], + random=True, + ) + .withColumn( + "WELL_ORIENTATION", + "string", + values=["Horizontal"], + random=True, + ) + .withColumn( + "PRODUCING_FORMATION", + "string", + values=[formation], + random=True, + ) + .withColumn( + "CURRENT_STATUS", + "string", + values=["Producing", "Shut-in", "Plugged and Abandoned", "Planned"], + random=True, + weights=[80, 10, 5, 5], + ) + .withColumn( + "TOTAL_DEPTH", "integer", minValue=12000, maxValue=20000, random=True + ) + .withColumn( + "SPUD_DATE", "date", begin="2020-01-01", end="2025-02-14", random=True + ) + .withColumn( + "COMPLETION_DATE", + "date", + begin="2020-01-01", + end="2025-02-14", + random=True, + ) + .withColumn( + "SURFACE_CASING_DEPTH", + "integer", + minValue=500, + maxValue=800, + random=True, + ) + .withColumn("OPERATOR_NAME", "string", 
values=["OPERATOR_XYZ"]) + .withColumn( + "PERMIT_DATE", "date", begin="2019-01-01", end="2025-02-14", random=True + ) + .withColumn( + "q_i", + "double", + values=[q_i] + ) + .withColumn( + "d", + "double", + values=[d] + ) + .withColumn( + "b", + "double", + values=[b_factor] + ) + ) + + return data_spec.build() + +# COMMAND ---------- + +# DBTITLE 1,Daily Production Data Generation Function +def generate_daily_production( + spark_session: "pyspark.sql.SparkSession", + well_num: int, + q_i: float, + d: float, + b_factor: float, + q_i_multiplier: float, + partitions: int = 4, + randomSeed: int = None +) -> "pyspark.sql.DataFrame": + """ + Creates a DataFrame with daily production data. + + Args: + spark_session: Current SparkSession + well_num: Well number + q_i: Initial production rate + d: Initial decline rate + b_factor: ARPS b-factor + q_i_multiplier: Initial production rate multiplier to randomness + partitions: Number of partitions for parallelism (optional) + randomSeed: Random seed for reproducibility (optional) + + Returns: + A Spark DataFrame with daily production data + """ + # Randomly determine the number of days to generate for this well (between 100 and 700) + days_to_generate = int(round(random.uniform(100, 700))) + if randomSeed is None: + randomSeed = int(round(random.uniform(20, 1000), 0)) + data_gen = ( + dg.DataGenerator( + sparkSession=spark_session, + name="type_curve", + rows=days_to_generate, + partitions=partitions, + randomSeed=randomSeed, + ) + # Assign the unique well number (API or identifier) to all rows + .withColumn("well_num", "bigInt", values=[well_num]) + # Generate the day index from first production (1 to 1000, but only as many as days_to_generate) + .withColumn("day_from_first_production", "integer", minValue=1, maxValue=1000) + # Set the first production date as today minus the number of days generated + .withColumn( + "first_production_date", + "date", + values=[date.today() - timedelta(days=days_to_generate)], + ) + # Calculate the actual date for each row by adding the day offset to the first production date + .withColumn( + "date", + "date", + expr="date_add(first_production_date, day_from_first_production)", + ) + # Assign ARPS initial production rate (q_i) for this well + .withColumn("q_i", "double", values=[q_i]) + # Assign ARPS initial decline rate (d) for this well + .withColumn("d", "double", values=[d]) + # Assign ARPS b-factor for this well + .withColumn("b", "double", values=[b_factor]) + # Introduce a multiplier to q_i to simulate rare production shut-ins (mostly 1.0, sometimes 0) + .withColumn( + "q_i_multiplier", + "double", + values=[q_i_multiplier, 0], + weights=[ + 97, + 3, + ], # 97% chance of normal production, 3% chance of zero (shut-in) + random=True, + ) + # Add a small random variation to production to simulate measurement noise or operational variability + .withColumn("variation", "double", expr="rand() * 0.1 + 0.95") + # Calculate actual oil production (BOPD) using the ARPS decline curve formula with all modifiers + .withColumn( + "actuals_bopd", + "double", + baseColumn=["q_i", "d", "b", "q_i_multiplier", "variation"], + expr="(q_i * q_i_multiplier) / power(1 + b * d * variation * day_from_first_production, 1/b)", + ) + ) + # Build and return the synthetic daily production DataFrame + return data_gen.build() + +# COMMAND ---------- + +# DBTITLE 1,Type Curve Data Generation Function +def generate_type_curve_forecast( + spark_session: "pyspark.sql.SparkSession", + formation: str, + q_i: float, + d: float, + b_factor: float, 
+ partitions: int = 4, + randomSeed: int = None +) -> "pyspark.sql.DataFrame": + """ + Creates a DataFrame with type curve forecast data. + + Args: + spark_session: Current SparkSession + formation: Formation name + q_i: Initial production rate (BOPD) + d: Initial decline rate + b_factor: ARPS b-factor + partitions: Number of partitions for parallelism (optional) + randomSeed: Random seed for reproducibility (optional) + + Returns: + A Spark DataFrame with type curve forecast data + """ + days_to_generate = 2000 # Number of days to forecast in the type curve + + if randomSeed is None: + randomSeed = int(round(random.uniform(20, 1000), 0)) # Set random seed if not provided + + data_gen = ( + dg.DataGenerator( + sparkSession=spark_session, + name="type_curve", + rows=days_to_generate, + partitions=partitions, + randomSeed=randomSeed, + ) + # Add formation name as a column + .withColumn("formation", "STRING", values=[formation]) + # Generate day index for forecast (1 to 1000) + .withColumn("day_from_first_production", "integer", minValue=1, maxValue=1000) + # Assign ARPS parameters to all rows + .withColumn("q_i", "double", values=[q_i]) + .withColumn("d", "double", values=[d]) + .withColumn("b", "double", values=[b_factor]) + # Add small random variation to simulate operational variability + .withColumn("variation", "double", expr="rand() * 0.1 + 0.95") + # Calculate forecasted BOPD using ARPS decline curve formula + .withColumn( + "forecast_bopd", + "double", + baseColumn=["q_i","d","b","day_from_first_production"], + expr="(q_i ) / power(1 + b * d * day_from_first_production, 1/b)", + ) + ) + return data_gen.build() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Section: Generate Synthetic Data +# MAGIC +# MAGIC This section brings together all previous components to generate synthetic well header data, daily production profiles, and type curve forecasts for multiple formations. +# MAGIC +# MAGIC **In this section, you will:** +# MAGIC - Define ARPS decline curve parameters for each formation +# MAGIC - Generate and display well header and production data +# MAGIC - Create type curve forecasts +# MAGIC +# MAGIC This is the core of the tutorial, demonstrating the full workflow from parameter definition to data generation and visualization. 
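For reference (not part of the patch), the `expr` strings used in the generator functions above implement the standard ARPS hyperbolic decline relation, where q_i is the initial rate in BOPD, the `d` parameter is the initial decline rate D_i, `b` is the hyperbolic exponent, and t is days from first production:

q(t) = \frac{q_i}{\left(1 + b \, D_i \, t\right)^{1/b}}

The daily-production generator additionally applies a q_i shut-in multiplier (zero roughly 3% of the time) and perturbs the decline term with a small random variation factor, which is why generated actuals scatter around this curve.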
+ +# COMMAND ---------- + +# DBTITLE 1,Data Generation Implementation +# Define type curve parameters for each formation using ARPS decline curve methodology +type_curve_dict = { + "FORMATION_A": {"q_i": 6000, "d": 0.01, "b_factor": 0.8}, + "FORMATION_B": {"q_i": 7000, "d": 0.011, "b_factor": 0.7}, + "FORMATION_C": {"q_i": 5500, "d": 0.009, "b_factor": 0.8}, + "FORMATION_D": {"q_i": 5750, "d": 0.011, "b_factor": 0.7}, +} + +# Generate well header data for each formation using the defined type curve parameters +# Each formation will have a random number of wells (between 10 and 20) +well_header_specs = [ + generate_well_header_data( + spark, + formation, + params["q_i"], + params["d"], + params["b_factor"], + random.randint(10, 20) # Randomly select number of assets per formation + ) + for formation, params in type_curve_dict.items() +] + +# Union all Spark DataFrames for well headers +wells_df = well_header_specs[0] +for spec in well_header_specs[1:]: + wells_df = wells_df.unionByName(spec) +print("WELLS DATAFRAME") +display(wells_df) + +# Convert the well header DataFrame to a dictionary for easy access by column +wells_dict = wells_df.toPandas().to_dict() +print(wells_dict) + +# Generate daily production data for each well using ARPS decline curve parameters +# Loop over each well and create a production profile +daily_prod_specs = [ + generate_daily_production( + spark, + wells_dict["API_NUMBER"][i], # Use API number as unique well identifier + wells_dict["q_i"][i], # Initial production rate (BOPD) + wells_dict["d"][i], # Initial decline rate + wells_dict["b"][i], # ARPS b-factor + 1.0 # Production rate multiplier (set to 1.0 for base case) + ) + for i in range(len(next(iter(wells_dict.values())))) +] + +# Union all Spark DataFrames for daily production +daily_production_df = daily_prod_specs[0] +for spec in daily_prod_specs[1:]: + daily_production_df = daily_production_df.unionByName(spec) +print("DAILY PRODUCTION DATAFRAME") +display(daily_production_df) + +# Generate type curve data for each formation using the defined type curve parameters +# Each formation will have a random number of wells (between 10 and 20) +type_curve_specs = [ + generate_type_curve_forecast( + spark, + formation, + params["q_i"], + params["d"], + params["b_factor"], + ) + for formation, params in type_curve_dict.items() +] + +# Union all Spark DataFrames for type curves +type_curve_df = type_curve_specs[0] +for spec in type_curve_specs[1:]: + type_curve_df = type_curve_df.unionByName(spec) +print("TYPE CURVE DATAFRAME") +display(type_curve_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Section: Visualize Synthetic Production Data +# MAGIC +# MAGIC This section demonstrates how to visualize the generated synthetic production data using Matplotlib and Seaborn. +# MAGIC +# MAGIC You will learn how to plot oil production decline curves for a sample of wells, using standard petroleum engineering units (BOPD) and SPE nomenclature. +# MAGIC +# MAGIC The second chart shows the visual for the forecasted type curves that are created and the basis of daily production values. These tables can be merged together for additional analytical use cases comparing actual to forecasted values. 
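As a concrete illustration of that merge (a sketch, not part of the original notebook), the three DataFrames generated above can be joined so actual and forecasted rates line up per well and day; `wells_slim` and `actual_vs_forecast_df` are names introduced here purely for illustration.

# Illustrative only: join daily actuals to the formation-level type curve forecast
wells_slim = wells_df.select("API_NUMBER", "PRODUCING_FORMATION")

actual_vs_forecast_df = (
    daily_production_df
    .join(wells_slim, daily_production_df["well_num"] == wells_slim["API_NUMBER"])
    .join(
        type_curve_df,
        (wells_slim["PRODUCING_FORMATION"] == type_curve_df["formation"])
        & (daily_production_df["day_from_first_production"] == type_curve_df["day_from_first_production"]),
    )
    .select(
        daily_production_df["well_num"],
        daily_production_df["day_from_first_production"],
        daily_production_df["actuals_bopd"],
        type_curve_df["forecast_bopd"],
    )
)
display(actual_vs_forecast_df)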
+# MAGIC + +# COMMAND ---------- + +# DBTITLE 1,Daily Production Visualization +# Sample 5 wells and show all rows for those wells +sampled_wells = daily_production_df.select("well_num").dropDuplicates().limit(5) +sampled_df = daily_production_df.join(sampled_wells, on="well_num") + +pdf = sampled_df.select("well_num", "date", "actuals_bopd").toPandas() + +plt.figure(figsize=(10, 6)) +for well_num, group in pdf.groupby("well_num"): + plt.plot(group["date"], group["actuals_bopd"], label=str(well_num), linestyle='-') + +plt.title("Oil Production BOPD (All Days for Sampled Wells)") +plt.xlabel("date") +plt.ylabel("Production Rate (barrels per day)") +plt.legend(title="Well Num") +plt.grid(True) +plt.show() + +# COMMAND ---------- + +# DBTITLE 1,Type Curve Visualization +# Sample 5 wells and show all rows for those wells +sampled_wells = daily_production_df.select("well_num").dropDuplicates().limit(5) +sampled_df = daily_production_df.join(sampled_wells, on="well_num") + +pdf = ( + type_curve_df + .select("formation", "day_from_first_production", "forecast_bopd") + .orderBy("formation", "day_from_first_production") + .toPandas() +) + +plt.figure(figsize=(10, 6)) +for formation, group in pdf.groupby("formation"): + plt.plot(group["day_from_first_production"], group["forecast_bopd"], label=str(formation), linestyle='-') + +plt.title("Forecasted Production BOPD ") +plt.xlabel("Days from first Production") +plt.ylabel("Production Rate (barrels per day)") +plt.legend(title="Formation") +plt.grid(True) +plt.show() From 7a93b01f8efcbcdd17407e8878721260dbde0eed Mon Sep 17 00:00:00 2001 From: ManishNamburi <76540628+ManishNamburi@users.noreply.github.com> Date: Wed, 15 Oct 2025 13:55:19 -0400 Subject: [PATCH 14/20] Example notebook for Gaming industry (#362) * Add files via upload A notebook that generates synthetic log in data for various gamers. The gamers can have multiple devices that are consistent across them and are located in different areas. * Update VideoGameLoginSyntheticDataGeneration.py Fixed code with comments from PR * Update and rename VideoGameLoginSyntheticDataGeneration.py to gaming_data_generation.py --- examples/notebooks/gaming_data_generation.py | 266 +++++++++++++++++++ 1 file changed, 266 insertions(+) create mode 100644 examples/notebooks/gaming_data_generation.py diff --git a/examples/notebooks/gaming_data_generation.py b/examples/notebooks/gaming_data_generation.py new file mode 100644 index 00000000..54e7e712 --- /dev/null +++ b/examples/notebooks/gaming_data_generation.py @@ -0,0 +1,266 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC +# MAGIC # Getting Started with the Databricks Labs Data Generator +# MAGIC This notebook provides an introduction to synthetic data generation using the [Databricks Labs Data Generator (`dbldatagen`)](https://databrickslabs.github.io/dbldatagen/public_docs/index.html). This data generator is useful for generating large synthetic datasets for development, testing, benchmarking, proofs-of-concept, and other use-cases. +# MAGIC +# MAGIC The notebook simulates data for a user login scenario for the gaming industry. + +# COMMAND ---------- + +# DBTITLE 1,Install dbldatagen +# dbldatagen can be installed using pip install commands, as a cluster-scoped library, or as a serverless environment-scoped library. 
+%pip install dbldatagen + +# COMMAND ---------- + +# DBTITLE 1,Import Modules +import dbldatagen as dg + +from pyspark.sql.types import DoubleType, StringType, TimestampType, LongType +from pyspark.sql.functions import col, expr, sha2, to_date, hour + +# COMMAND ---------- + +# DBTITLE 1,Set up Parameters +# Set up how many rows we want along with how many users, devices and IPs we want +ROW_COUNT = 4500000 +NUMBER_OF_USERS = 200000 +NUMBER_OF_DEVICES = NUMBER_OF_USERS + 50000 +NUMBER_OF_IPS = 40000 + +START_TIMESTAMP = "2025-03-01 00:00:00" +END_TIMESTAMP = "2025-03-30 00:00:00" + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Data Generation Specifications +# MAGIC +# MAGIC Let's start by generating a DataFrame with rows representing unique login information. Data generation is controlled by a `DataGenerator` object. Each `DataGenerator` can be extended with rules specifying the output schema and value generation. Columns can be defined using `withColumn(...)` with a variety of parameters. +# MAGIC +# MAGIC **colName** – Name of column to add. If this conflicts with the underlying seed column (id), it is recommended that the seed column name is customized during the construction of the data generator spec. +# MAGIC +# MAGIC **colType** – Data type for column. This may be specified as either a type from one of the possible pyspark.sql.types (e.g. StringType, DecimalType(10,3) etc) or as a string containing a Spark SQL type definition (i.e String, array, map) +# MAGIC +# MAGIC **omit** – if True, the column will be omitted from the final set of columns in the generated data. Used to create columns that are used by other columns as intermediate results. Defaults to False +# MAGIC +# MAGIC **expr** – Specifies SQL expression used to create column value. If specified, overrides the default rules for creating column value. Defaults to None +# MAGIC +# MAGIC **baseColumn** – String or list of columns to control order of generation of columns. 
If not specified, column is dependent on base seed column (which defaults to id) + +# COMMAND ---------- + +# DBTITLE 1,Generate a DataFrame +default_annotations_spec = ( + dg.DataGenerator(spark, name="default_annotations_spec", rows=ROW_COUNT) + .withColumn( + "EVENT_TIMESTAMP", + TimestampType(), + data_range=dg.DateRange(START_TIMESTAMP, END_TIMESTAMP, "seconds=1"), + random=True, + ) # Random event timestamp within the specified range + .withColumn( + "internal_ACCOUNTID", + LongType(), + minValue=0x1000000000000, + uniqueValues=NUMBER_OF_USERS, + omit=True, + baseColumnType="hash", + ) # Internal unique account id, omitted from output, used for deterministic hashing + .withColumn( + "ACCOUNTID", StringType(), format="0x%032x", baseColumn="internal_ACCOUNTID" + ) # Public account id as hex string + .withColumn( + "internal_DEVICEID", + LongType(), + minValue=0x1000000000000, + uniqueValues=NUMBER_OF_DEVICES, + omit=True, + baseColumnType="hash", + baseColumn="internal_ACCOUNTID", + ) # Internal device id, based on account, omitted from output + .withColumn( + "DEVICEID", StringType(), format="0x%032x", baseColumn="internal_DEVICEID" + ) # Public device id as hex string + .withColumn("APP_VERSION", StringType(), values=["current"]) # Static app version + .withColumn( + "AUTHMETHOD", StringType(), values=["OAuth", "password"] + ) # Auth method, random selection + # Assign clientName based on DEVICEID deterministically + .withColumn( + "CLIENTNAME", + StringType(), + expr=""" + element_at( + array('SwitchGameClient','XboxGameClient','PlaystationGameClient','PCGameClient'), + (pmod(abs(hash(DEVICEID)), 4) + 1) + ) + """, + ) + .withColumn( + "CLIENTID", + StringType(), + expr="sha2(concat(ACCOUNTID, CLIENTNAME), 256)", + baseColumn=["ACCOUNTID", "CLIENTNAME"], + ) # Deterministic clientId based on ACCOUNTID and clientName + .withColumn( + "SESSION_ID", + StringType(), + expr="sha2(concat(ACCOUNTID, CLIENTID), 256)", + ) # Session correlation id, deterministic hash + .withColumn( + "country", + StringType(), + values=["USA", "UK", "AUS"], + weights=[0.6, 0.2, 0.2], + baseColumn="ACCOUNTID", + random=True, + ) # Assign country with 60% USA, 20% UK, 20% AUS + .withColumn( + "APPENV", StringType(), values=["prod"] + ) # Static environment value + .withColumn( + "EVENT_TYPE", StringType(), values=["account_login_success"] + ) # Static event type + # Assign geoip_city_name based on country and ACCOUNTID + .withColumn( + "CITY", + StringType(), + expr=""" + CASE + WHEN country = 'USA' THEN element_at(array('New York', 'San Francisco', 'Chicago'), pmod(abs(hash(ACCOUNTID)), 3) + 1) + WHEN country = 'UK' THEN 'London' + WHEN country = 'AUS' THEN 'Sydney' + END + """, + baseColumn=["country", "ACCOUNTID"], + ) + .withColumn( + "COUNTRY_CODE2", + StringType(), + expr="CASE WHEN country = 'USA' THEN 'US' WHEN country = 'UK' THEN 'UK' WHEN country = 'AUS' THEN 'AU' END", + baseColumn=["country"], + ) # Country code + # Assign ISP based on country and ACCOUNTID + .withColumn( + "ISP", + StringType(), + expr=""" + CASE + WHEN country = 'USA' THEN element_at(array('Comcast', 'AT&T', 'Verizon', 'Spectrum', 'Cox'), pmod(abs(hash(ACCOUNTID)), 5) + 1) + WHEN country = 'UK' THEN element_at(array('BT', 'Sky', 'Virgin Media', 'TalkTalk', 'EE'), pmod(abs(hash(ACCOUNTID)), 5) + 1) + WHEN country = 'AUS' THEN element_at(array('Telstra', 'Optus', 'TPG', 'Aussie Broadband', 'iiNet'), pmod(abs(hash(ACCOUNTID)), 5) + 1) + ELSE 'Unknown ISP' + END + """, + baseColumn=["country", "ACCOUNTID"], + ) + # Assign latitude 
based on city + .withColumn( + "LATITUDE", + DoubleType(), + expr=""" + CASE + WHEN CITY = 'New York' THEN 40.7128 + WHEN CITY = 'San Francisco' THEN 37.7749 + WHEN CITY = 'Chicago' THEN 41.8781 + WHEN CITY = 'London' THEN 51.5074 + WHEN CITY = 'Sydney' THEN -33.8688 + ELSE 0.0 + END + """, + baseColumn="CITY", + ) + # Assign longitude based on city + .withColumn( + "LONGITUDE", + DoubleType(), + expr=""" + CASE + WHEN CITY = 'New York' THEN -74.0060 + WHEN CITY = 'San Francisco' THEN -122.4194 + WHEN CITY = 'Chicago' THEN -87.6298 + WHEN CITY = 'London' THEN -0.1278 + WHEN CITY = 'Sydney' THEN 151.2093 + ELSE 0.0 + END + """, + baseColumn="CITY", + ) + # Assign region name based on country and city + .withColumn( + "REGION_NAME", + StringType(), + expr=""" + CASE + WHEN country = 'USA' THEN + CASE + WHEN CITY = 'New York' THEN 'New York' + WHEN CITY = 'San Francisco' THEN 'California' + WHEN CITY = 'Chicago' THEN 'Illinois' + ELSE 'Unknown' + END + WHEN country = 'UK' THEN 'England' + WHEN country = 'AUS' THEN 'New South Wales' + ELSE 'Unknown' + END + """, + baseColumn=["country", "CITY"], + ) + # Internal IP address as integer, unique per device, omitted from output + .withColumn( + "internal_REQUESTIPADDRESS", + LongType(), + minValue=0x1000000000000, + uniqueValues=NUMBER_OF_IPS, + omit=True, + baseColumnType="hash", + baseColumn="internal_DEVICEID", + ) + # Convert internal IP integer to dotted quad string + .withColumn( + "REQUESTIPADDRESS", + StringType(), + expr=""" + concat( + cast((internal_REQUESTIPADDRESS >> 24) & 255 as string), '.', + cast((internal_REQUESTIPADDRESS >> 16) & 255 as string), '.', + cast((internal_REQUESTIPADDRESS >> 8) & 255 as string), '.', + cast(internal_REQUESTIPADDRESS & 255 as string) + ) + """, + baseColumn="internal_REQUESTIPADDRESS", + ) + # Generate user agent string using clientName and SESSION_ID + .withColumn( + "USERAGENT", + StringType(), + expr="concat('Launch/1.0+', CLIENTNAME, '(', CLIENTNAME, '/)/', SESSION_ID)", + baseColumn=["CLIENTNAME", "SESSION_ID"], + ) +) +# Build creates a DataFrame from the DataGenerator +default_logins_df = default_annotations_spec.build() + +# COMMAND ---------- + +# DBTITLE 1,Transform the Dataframe +logins_df = default_logins_df.withColumn( + "EVENT_HOUR", hour(col("EVENT_TIMESTAMP")) +).withColumn("EVENT_DATE", to_date(col("EVENT_TIMESTAMP"))) + +# COMMAND ---------- + +# DBTITLE 1,Look at the Data +display(logins_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Write Data + +# COMMAND ---------- + + +transformed_df.write.mode("overwrite").saveAsTable("main.test.EVENT_ACCOUNT_LOGIN_SUCCESS") From 7cb378906a36d15dd3cf85306e72a160f8357209 Mon Sep 17 00:00:00 2001 From: Greg Hansen <163584195+ghanse@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:52:35 -0400 Subject: [PATCH 15/20] Format `data_analyzer` and `text_generation_plugins` (#370) * Format data_analyzer.py and text_generator_plugins.py * Update formatting and fix docstrings * Update inferred max value for DecimalType columns * Update docstrings and inferred max value for DecimalType columns --- dbldatagen/data_analyzer.py | 601 +++++++++++++++------------ dbldatagen/text_generator_plugins.py | 400 ++++++++++-------- pyproject.toml | 8 - 3 files changed, 563 insertions(+), 446 deletions(-) diff --git a/dbldatagen/data_analyzer.py b/dbldatagen/data_analyzer.py index 195c49b9..5dcaaf92 100644 --- a/dbldatagen/data_analyzer.py +++ b/dbldatagen/data_analyzer.py @@ -8,60 +8,72 @@ This code is experimental and both APIs and code generated is liable to 
change in future versions. """ import logging +from typing import SupportsFloat, SupportsIndex -import pyspark.sql as ssql -from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \ - TimestampType, DateType, DecimalType, ByteType, BinaryType, StructType, ArrayType, DataType +from pyspark.sql import DataFrame, Row, SparkSession, types -from .spark_singleton import SparkSingleton -from .utils import strip_margins +from dbldatagen.spark_singleton import SparkSingleton +from dbldatagen.utils import strip_margins -SUMMARY_FIELD_NAME = "summary" -SUMMARY_FIELD_NAME_RENAMED = "__summary__" -DATA_SUMMARY_FIELD_NAME = "__data_summary__" + +SUMMARY_FIELD_NAME: str = "summary" +SUMMARY_FIELD_NAME_RENAMED: str = "__summary__" +DATA_SUMMARY_FIELD_NAME: str = "__data_summary__" class DataAnalyzer: - """This class is used to analyze an existing data set to assist in generating a test data set with similar - characteristics, and to generate code from existing schemas and data + """ + This class is used to analyze an existing dataset to assist in generating a test data set with similar data + characteristics. Analyzer results can be used to generate code from existing schemas and data. - :param df: Spark dataframe to analyze - :param sparkSession: Spark session instance to use when performing spark operations - :param debug: If True, additional debug information is logged - :param verbose: If True, additional information is logged + :param df: Spark ``DataFrame`` to analyze + :param sparkSession: ``SparkSession`` to use + :param debug: Whether to log additional debug information (default `False`) + :param verbose: Whether to log detailed execution information (default `False`) .. warning:: Experimental - """ - _DEFAULT_GENERATED_NAME = "synthetic_data" - - _GENERATED_COMMENT = strip_margins(""" - |# Code snippet generated with Databricks Labs Data Generator (`dbldatagen`) DataAnalyzer class - |# Install with `pip install dbldatagen` or in notebook with `%pip install dbldatagen` - |# See the following resources for more details: - |# - |# Getting Started - [https://databrickslabs.github.io/dbldatagen/public_docs/APIDOCS.html] - |# Github project - [https://github.com/databrickslabs/dbldatagen] - |#""", '|') - - _GENERATED_FROM_SCHEMA_COMMENT = strip_margins(""" - |# Column definitions are stubs only - modify to generate correct data - |#""", '|') - - def __init__(self, df=None, sparkSession=None, debug=False, verbose=False): - """ Constructor: - :param df: Dataframe to analyze - :param sparkSession: Spark session to use + debug: bool + verbose: bool + _sparkSession: SparkSession + _df: DataFrame + _dataSummary: dict[str, dict[str, object]] | None + _DEFAULT_GENERATED_NAME: str = "synthetic_data" + _GENERATED_COMMENT: str = strip_margins( + """ + |# Code snippet generated with Databricks Labs Data Generator (`dbldatagen`) DataAnalyzer class + |# Install with `pip install dbldatagen` or in notebook with `%pip install dbldatagen` + |# See the following resources for more details: + |# + |# Getting Started - [https://databrickslabs.github.io/dbldatagen/public_docs/APIDOCS.html] + |# Github project - [https://github.com/databrickslabs/dbldatagen] + |# + """, + marginChar="|" + ) + + _GENERATED_FROM_SCHEMA_COMMENT: str = strip_margins( """ - # set up logging + |# Column definitions are stubs only - modify to generate correct data + |# + """, + marginChar="|" + ) + + def __init__( + self, + df: DataFrame | None = None, + sparkSession: SparkSession | None = 
None, + debug: bool = False, + verbose: bool = False + ) -> None: self.verbose = verbose self.debug = debug - self._setupLogger() - assert df is not None, "dataframe must be supplied" - + if df is None: + raise ValueError("Argument `df` must be supplied when initializing a `DataAnalyzer`") self._df = df if sparkSession is None: @@ -70,10 +82,10 @@ def __init__(self, df=None, sparkSession=None, debug=False, verbose=False): self._sparkSession = sparkSession self._dataSummary = None - def _setupLogger(self): - """Set up logging - - This will set the logger at warning, info or debug levels depending on the instance construction parameters + def _setupLogger(self) -> None: + """ + Sets up logging for the ``DataAnalyzer``. Configures the logger at warning, info or debug levels depending on + the user-requested behavior. """ self.logger = logging.getLogger("DataAnalyzer") if self.debug: @@ -83,49 +95,64 @@ def _setupLogger(self): else: self.logger.setLevel(logging.WARNING) - def _displayRow(self, row): - """Display details for row""" - results = [] + @staticmethod + def _displayRow(row: Row) -> str: + """ + Displays details for a row as a string. + + :param row: PySpark ``Row`` object to display + :returns: String representing row-level details + """ row_key_pairs = row.asDict() - for x in row_key_pairs: - results.append(f"{x}: {row[x]}") - - return ", ".join(results) - - def _addMeasureToSummary(self, measureName, *, summaryExpr="''", fieldExprs=None, dfData=None, rowLimit=1, - dfSummary=None): - """ Add a measure to the summary dataframe - - :param measureName: Name of measure - :param summaryExpr: Summary expression - :param fieldExprs: list of field expressions (or generator) - :param dfData: Source data df - data being summarized - :param rowLimit: Number of rows to get for measure - :param dfSummary: Summary df - :return: dfSummary with new measure added + return ",".join([f"{x}: {row[x]}" for x in row_key_pairs]) + + @staticmethod + def _addMeasureToSummary( + measureName: str, + *, + summaryExpr: str = "''", + fieldExprs: list[str] | None = None, + dfData: DataFrame | None, + rowLimit: int = 1, + dfSummary: DataFrame | None = None + ) -> DataFrame: + """ + Adds a new measure to the summary ``DataFrame``. 
+ + :param measureName: Measure name + :param summaryExpr: Measure expression as a Spark SQL statement + :param fieldExprs: Optional list of field expressions as Spark SQL Statements + :param dfData: Source ``DataFrame`` to summarize + :param rowLimit: Number of rows to use for ``DataFrame`` summarization + :param dfSummary: Summary metrics ``DataFrame`` + :returns: Summary metrics ``DataFrame`` with the added measure """ - assert dfData is not None, "source data dataframe must be supplied" - assert measureName is not None and len(measureName) > 0, "invalid measure name" + if dfData is None: + raise ValueError("Input DataFrame `dfData` must be supplied when adding measures to a summary") + + if measureName is None: + raise ValueError("Input measure name must be a non-empty string") # add measure name and measure summary - exprs = [f"'{measureName}' as measure_", f"string({summaryExpr}) as summary_"] + expressions = [f"'{measureName}' as measure_", f"string({summaryExpr}) as summary_"] - # add measures for fields - exprs.extend(fieldExprs) + if fieldExprs: + expressions.extend(fieldExprs) if dfSummary is not None: - dfResult = dfSummary.union(dfData.selectExpr(*exprs).limit(rowLimit)) - else: - dfResult = dfData.selectExpr(*exprs).limit(rowLimit) + return dfSummary.union(dfData.selectExpr(*expressions).limit(rowLimit)) - return dfResult + return dfData.selectExpr(*expressions).limit(rowLimit) - def _get_dataframe_describe_stats(self, df): - """ Get summary statistics for dataframe handling renaming of summary field if necessary""" - print("schema", df.schema) + @staticmethod + def _get_dataframe_describe_stats(df: DataFrame) -> DataFrame: + """ + Gets a summary ``DataFrame`` with column-level statistics about the input ``DataFrame``. + :param df: Input ``DataFrame`` + :returns: Summary ``DataFrame`` with column-level statistics + """ src_fields = [fld.name for fld in df.schema.fields] - print("src_fields", src_fields) renamed_summary = False # get summary statistics handling the case where a field named 'summary' exists @@ -145,114 +172,128 @@ def _get_dataframe_describe_stats(self, df): return summary_df - def summarizeToDF(self): - """ Generate summary analysis of data set as dataframe + def summarizeToDF(self) -> DataFrame: + """ + Generates a summary analysis of the input ``DataFrame`` of the ``DataAnalyzer``. - :return: Summary results as dataframe + :returns: Summary ``DataFrame`` with analyzer results + .. note:: The resulting dataframe can be displayed with the ``display`` function in a notebook environment or with the ``show`` method. The output is also used in code generation to generate more accurate code. 
""" - self._df.cache().createOrReplaceTempView("data_analysis_summary") + self._df.createOrReplaceTempView("data_analysis_summary") total_count = self._df.count() * 1.0 - dtypes = self._df.dtypes - - # schema information - dfDataSummary = self._addMeasureToSummary( - 'schema', + data_summary_df = self._addMeasureToSummary( + measureName="schema", summaryExpr=f"""to_json(named_struct('column_count', {len(dtypes)}))""", fieldExprs=[f"'{dtype[1]}' as {dtype[0]}" for dtype in dtypes], - dfData=self._df) + dfData=self._df + ) - # count - dfDataSummary = self._addMeasureToSummary( - 'count', + data_summary_df = self._addMeasureToSummary( + measureName="count", summaryExpr=f"{total_count}", fieldExprs=[f"string(count({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) - - dfDataSummary = self._addMeasureToSummary( - 'null_probability', - fieldExprs=[f"""string( round( ({total_count} - count({dtype[0]})) /{total_count}, 2)) as {dtype[0]}""" - for dtype in dtypes], + dfSummary=data_summary_df + ) + + data_summary_df = self._addMeasureToSummary( + measureName="null_probability", + fieldExprs=[ + f"""string( round( ({total_count} - count({dtype[0]})) /{total_count}, 2)) as {dtype[0]}""" + for dtype in dtypes + ], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) # distinct count - dfDataSummary = self._addMeasureToSummary( - 'distinct_count', + data_summary_df = self._addMeasureToSummary( + measureName="distinct_count", summaryExpr="count(distinct *)", fieldExprs=[f"string(count(distinct {dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) # min - dfDataSummary = self._addMeasureToSummary( - 'min', + data_summary_df = self._addMeasureToSummary( + measureName="min", fieldExprs=[f"string(min({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) - dfDataSummary = self._addMeasureToSummary( - 'max', + data_summary_df = self._addMeasureToSummary( + measureName="max", fieldExprs=[f"string(max({dtype[0]})) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) - descriptionDf = (self._get_dataframe_describe_stats(self._df) - .where(f"{DATA_SUMMARY_FIELD_NAME} in ('mean', 'stddev')")) - describeData = descriptionDf.collect() + description_df = ( + self + ._get_dataframe_describe_stats(self._df) + .where(f"{DATA_SUMMARY_FIELD_NAME} in ('mean', 'stddev')") + ) + description_data = description_df.collect() - for row in describeData: + for row in description_data: measure = row[DATA_SUMMARY_FIELD_NAME] - values = {k[0]: '' for k in dtypes} + values = {k[0]: "" for k in dtypes} row_key_pairs = row.asDict() for k1 in row_key_pairs: values[k1] = str(row[k1]) - dfDataSummary = self._addMeasureToSummary( - measure, + data_summary_df = self._addMeasureToSummary( + measureName=measure, fieldExprs=[f"'{values[dtype[0]]}'" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) # string characteristics for strings and string representation of other values - dfDataSummary = self._addMeasureToSummary( - 'print_len_min', + data_summary_df = self._addMeasureToSummary( + measureName="print_len_min", fieldExprs=[f"string(min(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) - dfDataSummary = self._addMeasureToSummary( - 
'print_len_max', + data_summary_df = self._addMeasureToSummary( + measureName="print_len_max", fieldExprs=[f"string(max(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes], dfData=self._df, - dfSummary=dfDataSummary) + dfSummary=data_summary_df + ) - return dfDataSummary + return data_summary_df - def summarize(self, suppressOutput=False): - """ Generate summary analysis of data set and return / print summary results + def summarize(self, suppressOutput: bool = False) -> str: + """ + Generates a summary analysis of the input ``DataFrame`` and returns the analysis as a string. Optionally prints + the summary analysis. - :param suppressOutput: If False, prints results to console also - :return: Summary results as string + :param suppressOutput: Whether to print the summary analysis (default `False`) + :return: Summary analysis as string """ - dfSummary = self.summarizeToDF() + summary_df = self.summarizeToDF() results = [ "Data set summary", "================" ] - for r in dfSummary.collect(): - results.append(self._displayRow(r)) + for row in summary_df.collect(): + results.append(self._displayRow(row)) summary = "\n".join([str(x) for x in results]) @@ -262,213 +303,229 @@ def summarize(self, suppressOutput=False): return summary @classmethod - def _valueFromSummary(cls, dataSummary, colName, measure, defaultValue): - """ Get value from data summary - - :param dataSummary: Data summary to search, optional - :param colName: Column name of column to get value for - :param measure: Measure name of measure to get value for - :param defaultValue: Default value if any other argument is not specified or value could not be found in - data summary - :return: Value from lookup or `defaultValue` if not found + def _valueFromSummary( + cls, + dataSummary: dict[str, dict[str, object]] | None = None, + colName: str | None = None, + measure: str | None = None, + defaultValue: int | float | str | None = None + ) -> object: + """ + Gets a measure value from a data summary given a measure name and column name. Returns a default value when the + measure value cannot be found. + + :param dataSummary: Optional data summary to search (if ``None``, the default value is returned) + :param colName: Optional column name + :param measure: Optional measure name + :param defaultValue: Default return value + :return: Measure value or default value """ - if dataSummary is not None and colName is not None and measure is not None: - if measure in dataSummary: - measureValues = dataSummary[measure] + if dataSummary is None or colName is None or measure is None: + return defaultValue - if colName in measureValues: - return measureValues[colName] + if measure not in dataSummary: + return defaultValue - # return default value if value could not be looked up or found - return defaultValue + measure_values = dataSummary[measure] + if colName not in measure_values: + return defaultValue + + return measure_values[colName] @classmethod - def _generatorDefaultAttributesFromType(cls, sqlType, colName=None, dataSummary=None, sourceDf=None): - """ Generate default set of attributes for each data type + def _generatorDefaultAttributesFromType( + cls, + sqlType: types.DataType, + colName: str | None = None, + dataSummary: dict | None = None + ) -> str: + """ + Generates a Spark SQL expression for the input column and data type. Optionally uses ``DataAnalyzer`` summary + statistics to create Spark SQL expressions for generating data similar to the input ``DataFrame``. 
- :param sqlType: Instance of `pyspark.sql.types.DataType` - :param colName: Name of column being generated - :param dataSummary: Map of maps of attributes from data summary, optional - :param sourceDf: Source dataframe to retrieve attributes of real data, optional - :return: Attribute string for supplied sqlType + :param sqlType: Data type as an instance of ``pyspark.sql.types.DataType`` + :param colName: Column name + :param dataSummary: Optional map of maps of attributes from the data summary + :return: Spark SQL expression for supplied column and data type - When generating code from a schema, we have no data heuristics to determine how data should be generated, - so goal is to just generate code that produces some data. + .. note:: + When generating expressions from a schema, no data heuristics are available to determine how data should be + generated. This method will use default values according to Spark's data type limits to generate working + expressions for data generation. Users are expected to modify the generated code to their needs. """ - assert isinstance(sqlType, DataType) + if not isinstance(sqlType, types.DataType): + raise ValueError( + f"Argument 'sqlType' with type {type(sqlType)} must be an instance of `pyspark.sql.types.DataType`" + ) - if sqlType == StringType(): + if sqlType == types.StringType(): result = """template=r'\\\\w'""" - elif sqlType in [IntegerType(), LongType()]: - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000) - result = f"""minValue={minValue}, maxValue={maxValue}""" - elif sqlType == ByteType(): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=127) - result = f"""minValue={minValue}, maxValue={maxValue}""" - elif sqlType == ShortType(): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=32767) - result = f"""minValue={minValue}, maxValue={maxValue}""" - elif sqlType == BooleanType(): + + elif sqlType in [types.IntegerType(), types.LongType()]: + min_value = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) + max_value = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000) + result = f"""minValue={min_value}, maxValue={max_value}""" + + elif sqlType == types.ByteType(): + min_value = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) + max_value = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=127) + result = f"""minValue={min_value}, maxValue={max_value}""" + + elif sqlType == types.ShortType(): + min_value = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) + max_value = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=32767) + result = f"""minValue={min_value}, maxValue={max_value}""" + + elif sqlType == types.BooleanType(): result = """expr='id % 2 = 1'""" - elif sqlType == DateType(): + + elif sqlType == types.DateType(): result = """expr='current_date()'""" - elif isinstance(sqlType, DecimalType): - minValue = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000) - result = f"""minValue={minValue}, maxValue={maxValue}""" - elif sqlType in [FloatType(), DoubleType()]: - minValue = cls._valueFromSummary(dataSummary, colName, 
"min", defaultValue=0.0) - maxValue = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000.0) - result = f"""minValue={minValue}, maxValue={maxValue}, step=0.1""" - elif sqlType == TimestampType(): + + elif sqlType == types.DecimalType(): + min_value = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0) + max_value = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000.0) + result = f"""minValue={min_value}, maxValue={max_value}""" + + elif sqlType in [types.FloatType(), types.DoubleType()]: + min_value = cls._valueFromSummary(dataSummary, colName, "min", defaultValue=0.0) + max_value = cls._valueFromSummary(dataSummary, colName, "max", defaultValue=1000000.0) + result = f"""minValue={min_value}, maxValue={max_value}, step=0.1""" + + elif sqlType == types.TimestampType(): result = """begin="2020-01-01 01:00:00", end="2020-12-31 23:59:00", interval="1 minute" """ - elif sqlType == BinaryType(): + + elif sqlType == types.BinaryType(): result = """expr="cast('dbldatagen generated synthetic data' as binary)" """ + else: result = """expr='null'""" - percentNullsValue = float(cls._valueFromSummary(dataSummary, colName, "null_probability", defaultValue=0.0)) + summary_value = cls._valueFromSummary(dataSummary, colName, "null_probability", defaultValue=0.0) + percent_nulls_value = ( + float(summary_value) if isinstance(summary_value, str | SupportsFloat | SupportsIndex) else 0.0 + ) - if percentNullsValue > 0.0: - result = result + f", percentNulls={percentNullsValue}" + if percent_nulls_value > 0.0: + result = result + f", percentNulls={percent_nulls_value}" return result @classmethod - def _scriptDataGeneratorCode(cls, schema, *, dataSummary=None, sourceDf=None, suppressOutput=False, name=None): + def _scriptDataGeneratorCode( + cls, + schema: types.StructType, + *, + dataSummary: dict | None = None, + sourceDf: DataFrame | None = None, + suppressOutput: bool = False, + name: str | None = None + ) -> str: """ - Generate outline data generator code from an existing dataframe - - This will generate a data generator spec from an existing dataframe. The resulting code - can be used to generate a data generation specification. + Generates code to build a ``DataGenerator`` from an existing dataframe. Analyzes the dataframe passed to the + constructor of the ``DataAnalyzer`` and returns a script for generating similar data. - Note at this point in time, the code generated is stub code only. - For most uses, it will require further modification - however it provides a starting point - for generation of the specification for a given data set. - - The dataframe to be analyzed is the dataframe passed to the constructor of the DataAnalyzer object. - - :param schema: Pyspark schema - i.e manually constructed StructType or return value from `dataframe.schema` - :param dataSummary: Map of maps of attributes from data summary, optional - :param sourceDf: Source dataframe to retrieve attributes of real data, optional - :param suppressOutput: Suppress printing of generated code if True + :param schema: Pyspark schema as a ``StructType`` + :param dataSummary: Optional map of maps of attributes from the data summary + :param sourceDf: Optional ``DataFrame`` to retrieve attributes from existing data + :param suppressOutput: Whether to suppress printing attributes during execution (default `False`) :param name: Optional name for data generator - :return: String containing skeleton code + :return: Data generation code string + .. 
note:: + Code generated by this method should be treated as experimental. For most uses, generated code requires further + modification. Results are intended to provide an initial script for generating data from the input dataset. """ - assert isinstance(schema, StructType), "expecting valid Pyspark Schema" - - stmts = [] + statements = [] if name is None: name = cls._DEFAULT_GENERATED_NAME - stmts.append(cls._GENERATED_COMMENT) - - stmts.append("import dbldatagen as dg") - stmts.append("import pyspark.sql.types") - - stmts.append(cls._GENERATED_FROM_SCHEMA_COMMENT) - - stmts.append(strip_margins( - f"""generation_spec = ( - | dg.DataGenerator(sparkSession=spark, - | name='{name}', + statements.append(cls._GENERATED_COMMENT) + statements.append("import dbldatagen as dg") + statements.append("import pyspark.sql.types") + statements.append(cls._GENERATED_FROM_SCHEMA_COMMENT) + statements.append( + strip_margins( + f"""generation_spec = ( + | dg.DataGenerator(sparkSession=spark, + | name='{name}', | rows=100000, | random=True, | )""", - '|')) + marginChar="|" + ) + ) indent = " " - for fld in schema.fields: - col_name = fld.name - col_type = fld.dataType.simpleString() - - if isinstance(fld.dataType, ArrayType): - col_type = fld.dataType.elementType.simpleString() - field_attributes = cls._generatorDefaultAttributesFromType(fld.dataType.elementType) # no data look up - array_attributes = """structType='array', numFeatures=(2,6)""" - name_and_type = f"""'{col_name}', '{col_type}'""" - stmts.append(indent + f""".withColumn({name_and_type}, {field_attributes}, {array_attributes})""") + for field in schema.fields: + column_name = field.name + column_type = field.dataType.simpleString() + + if isinstance(field.dataType, types.ArrayType): + column_type = field.dataType.elementType.simpleString() + field_attributes = cls._generatorDefaultAttributesFromType(field.dataType.elementType) + array_attributes = "structType='array', numFeatures=(2,6)" + name_and_type = f"'{column_name}', '{column_type}'" + statements.append(indent + f".withColumn({name_and_type}, {field_attributes}, {array_attributes})") else: - field_attributes = cls._generatorDefaultAttributesFromType(fld.dataType, - colName=col_name, - dataSummary=dataSummary, - sourceDf=sourceDf) - stmts.append(indent + f""".withColumn('{col_name}', '{col_type}', {field_attributes})""") - stmts.append(indent + ")") + field_attributes = cls._generatorDefaultAttributesFromType( + field.dataType, colName=column_name, dataSummary=dataSummary + ) + statements.append(indent + f".withColumn('{column_name}', '{column_type}', {field_attributes})") + statements.append(indent + ")") if not suppressOutput: - for line in stmts: + for line in statements: print(line) - return "\n".join(stmts) + return "\n".join(statements) @classmethod - def scriptDataGeneratorFromSchema(cls, schema, suppressOutput=False, name=None): + def scriptDataGeneratorFromSchema( + cls, schema: types.StructType, suppressOutput: bool = False, name: str | None = None + ) -> str: """ - Generate outline data generator code from an existing dataframe - - This will generate a data generator spec from an existing dataframe. The resulting code - can be used to generate a data generation specification. - - Note at this point in time, the code generated is stub code only. - For most uses, it will require further modification - however it provides a starting point - for generation of the specification for a given data set. 
- - The dataframe to be analyzed is the dataframe passed to the constructor of the DataAnalyzer object. + Generates code to build a ``DataGenerator`` from an existing dataframe schema. Analyzes the schema of the + ``DataFrame`` passed to the ``DataAnalyzer`` and returns a script for generating similar data. - :param schema: Pyspark schema - i.e manually constructed StructType or return value from `dataframe.schema` - :param suppressOutput: Suppress printing of generated code if True + :param schema: Pyspark schema as a ``StructType`` + :param suppressOutput: Whether to suppress printing attributes during execution (default `False`) :param name: Optional name for data generator - :return: String containing skeleton code + :return: Data generation code string + .. note:: + Code generated by this method should be treated as experimental. For most uses, generated code requires further + modification. Results are intended to provide an initial script for generating data from the input dataset. """ - return cls._scriptDataGeneratorCode(schema, - suppressOutput=suppressOutput, - name=name) + return cls._scriptDataGeneratorCode(schema, suppressOutput=suppressOutput, name=name) - def scriptDataGeneratorFromData(self, suppressOutput=False, name=None): + def scriptDataGeneratorFromData(self, suppressOutput: bool = False, name: str | None = None) -> str: """ - Generate outline data generator code from an existing dataframe + Generates code to build a ``DataGenerator`` from an existing dataframe. Analyzes statistical properties of the + ``DataFrame`` passed to the ``DataAnalyzer`` and returns a script for generating similar data. - This will generate a data generator spec from an existing dataframe. The resulting code - can be used to generate a data generation specification. - - Note at this point in time, the code generated is stub code only. - For most uses, it will require further modification - however it provides a starting point - for generation of the specification for a given data set - - The dataframe to be analyzed is the Spark dataframe passed to the constructor of the DataAnalyzer object - - :param suppressOutput: Suppress printing of generated code if True + :param suppressOutput: Whether to suppress printing attributes during execution (default `False`) :param name: Optional name for data generator - :return: String containing skeleton code + :return: Data generation code string + .. note:: + Code generated by this method should be treated as experimental. For most uses, generated code requires further + modification. Results are intended to provide an initial script for generating data from the input dataset. """ - assert self._df is not None - - if not isinstance(self._df, ssql.DataFrame): - self.logger.warning(strip_margins( - """The parameter `sourceDf` should be a valid Pyspark dataframe. 
- |Note this warning may false due to use of remote connection to a Spark cluster""", - '|')) + if not self._df: + raise ValueError("Missing `DataAnalyzer` property `df` for scripting a data generator from data") if self._dataSummary is None: df_summary = self.summarizeToDF() - self._dataSummary = {} + for row in df_summary.collect(): row_key_pairs = row.asDict() - self._dataSummary[row['measure_']] = row_key_pairs + self._dataSummary[row["measure_"]] = row_key_pairs - return self._scriptDataGeneratorCode(self._df.schema, - suppressOutput=suppressOutput, - name=name, - dataSummary=self._dataSummary, - sourceDf=self._df) + return self._scriptDataGeneratorCode( + self._df.schema, suppressOutput=suppressOutput, name=name, dataSummary=self._dataSummary, sourceDf=self._df + ) diff --git a/dbldatagen/text_generator_plugins.py b/dbldatagen/text_generator_plugins.py index 135d50eb..11cf8c61 100644 --- a/dbldatagen/text_generator_plugins.py +++ b/dbldatagen/text_generator_plugins.py @@ -8,19 +8,48 @@ import importlib import logging +from collections.abc import Callable +from types import ModuleType +from typing import Optional -from .text_generators import TextGenerator -from .utils import DataGenError +import pandas as pd + +from dbldatagen.text_generators import TextGenerator +from dbldatagen.utils import DataGenError + + +class _FnCallContext: + """ + Inner class for storing context between function calls. + + initial instances of random number generators, clients for services etc here during execution + of the `initFn` calls + + :param txtGen: - reference to outer PyfnText object + """ + textGenerator: "TextGenerator" + + def __init__(self, txtGen: "TextGenerator") -> None: + self.textGenerator = txtGen + + def __setattr__(self, name: str, value: object) -> None: + """Allow dynamic attribute setting for plugin context.""" + super().__setattr__(name, value) + + def __getattr__(self, name: str) -> object: + """Allow dynamic attribute access for plugin context.""" + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") class PyfuncText(TextGenerator): # lgtm [py/missing-equals] - """ Text generator that supports generating text from arbitrary Python function + """ + Text generator that supports generating text from an arbitrary Python function. - :param fn: function to call to generate text. - :param init: function to call to initialize context - :param initPerBatch: if init per batch is set to True, initialization of context is performed on every Pandas udf - call. Default is False. - :param name: String representing name of text generator when converted to string via ``repr`` or ``str`` + :param fn: Python function which generates text + :param init: Python function which creates an initial context/state + :param initPerBatch: Whether to call the initialization function for each invocation of the Pandas UDF which + generates text (default `false`) + :param name: Optional name of the text generator when converted to string via ``repr`` or ``str`` The two functions define the plugin model @@ -39,70 +68,75 @@ class PyfuncText(TextGenerator): # lgtm [py/missing-equals] enclosing text generator. .. note:: - There are no expectations of repeatability of data generation when using external code - or external libraries to generate text. + There are no expectations of repeatability of data generation when using external code + or external libraries to generate text. - However, custom code can call the base class method to get a Numpy random - number generator instance. 
This will have been seeded using the ``dbldatagen`` - random number seed if one was specified, so random numbers generated from this will be repeatable. + However, custom code can call the base class method to get a Numpy random + number generator instance. This will have been seeded using the ``dbldatagen`` + random number seed if one was specified, so random numbers generated from this will be repeatable. - The custom code may call the property ``randomSeed`` on the text generator object to get the random seed - which may be used to seed library specific initialization. + The custom code may call the property ``randomSeed`` on the text generator object to get the random seed + which may be used to seed library specific initialization. - This random seed property may have the values ``None`` or ``-1`` which should be treated as meaning dont - use a random seed. + This random seed property may have the values ``None`` or ``-1`` which should be treated as meaning dont + use a random seed. - The code does not guarantee thread or cross process safety. If a new instance of the random number - generator is needed, you may call the base class method with the argument `forceNewInstance` set to True. + The code does not guarantee thread or cross process safety. If a new instance of the random number + generator is needed, you may call the base class method with the argument `forceNewInstance` set to True. """ - - class _FnCallContext: - """ inner class to support storage of context between calls - - initial instances of random number generators, clients for services etc here during execution - of the `initFn` calls - - :param txtGen: - reference to outer PyfnText object - - """ - - def __init__(self, txtGen): - self.textGenerator = txtGen - - def __init__(self, fn, *, init=None, initPerBatch=False, name=None, rootProperty=None): + _name: str + _initPerBatch: bool + _rootProperty: object + _pyFn: Callable + _initFn: Callable | None + _context: _FnCallContext | None + + def __init__( + self, + fn: Callable, + *, + init: Callable | None = None, + initPerBatch: bool = False, + name: str | None = None, + rootProperty: object = None + ) -> None: super().__init__() - assert fn is not None or callable(fn), "Function must be provided wiith signature fn(context, oldValue)" - assert init is None or callable(init), "Init function must be a callable function or lambda if passed" + if not callable(fn): + raise ValueError("Function must be provided with signature fn(context, oldValue)") + + if init and not callable(init): + raise ValueError("Init function must be a callable function or lambda if passed") # if root property is provided, root property will be passed to generate text function self._rootProperty = rootProperty - self._pyFn = fn # generate text function self._initFn = init # context initialization function self._context = None # context used to hold library root object and other properties # if init per batch is True, initialization of context will be per UDF call - assert initPerBatch in [True, False], "initPerBatch must evaluate to boolean True or False" - self._initPerBatch = initPerBatch + if not isinstance(initPerBatch, bool): + raise ValueError("initPerBatch must evaluate to boolean True or False") + self._initPerBatch = initPerBatch self._name = name if name is not None else "PyfuncText" - def __str__(self): - """ Get string representation of object - ``name`` property is used to provide user friendly name for text generator + def __str__(self) -> str: """ - return 
f"{self._name}({repr(self._pyFn)}, init={self._initFn})" + Gets a string representation of the text generator using the ``name`` property. - def _getContext(self, forceNewInstance=False): - """ Get the context for plugin function calls + :returns: String representation of the text generator + """ + return f"{self._name}({self._pyFn!r}, init={self._initFn})" - :param forceNewInstance: if True, forces each call to create a new context - :return: existing or newly created context. + def _getContext(self, forceNewInstance: bool = False) -> _FnCallContext: + """ + Gets the context for plugin function calls. + :param forceNewInstance: Whether to create a new context for each call (default `False`) + :return: Existing or new context for plugin function calls """ - context = self._context - if context is None or forceNewInstance: - context = PyfuncText._FnCallContext(self) + if self._context is None or forceNewInstance: + context = _FnCallContext(self) # init context using context creator if any provided if self._initFn is not None: @@ -113,41 +147,42 @@ def _getContext(self, forceNewInstance=False): self._context = context else: return context - return self._context - def pandasGenerateText(self, v): - """ Called to generate text via Pandas UDF mechanism + return self._context - :param v: base value of column as Pandas Series + def pandasGenerateText(self, v: pd.Series) -> pd.Series: + """ + Generates text from input columns using a Pandas UDF. + :param v: Input column values as Pandas Series + :returns: Generated text values as a Pandas Series or DataFrame """ # save object properties in local vars to avoid overhead of object dereferences # on every call context = self._getContext(self._initPerBatch) evalFn = self._pyFn - rootProperty = getattr(context, self._rootProperty) if self._rootProperty is not None else None + rootProperty = getattr(context, str(self._rootProperty), None) if self._rootProperty else None # define functions to call with context and with root property - def _valueFromFn(originalValue): + def _valueFromFn(originalValue: object) -> object: return evalFn(context, originalValue) - def _valueFromFnWithRoot(originalValue): + def _valueFromFnWithRoot(_: object) -> object: return evalFn(rootProperty) if rootProperty is not None: - results = v.apply(_valueFromFnWithRoot, args=None) - else: - results = v.apply(_valueFromFn, args=None) + return v.apply(_valueFromFnWithRoot) - return results + return v.apply(_valueFromFn) class PyfuncTextFactory: - """PyfuncTextFactory applies syntactic wrapping around creation of PyfuncText objects + """ + Applies syntactic wrapping around the creation of PyfuncText objects. - :param name: name of generated object (when converted to string via ``str``) + :param name: Generated object name (when converted to string via ``str``) - It allows the use of the following constructs: + This class allows the use of the following constructs: .. 
code-block:: python @@ -180,87 +215,100 @@ def initFaker(ctx): init=initFaker, rootProperty="faker", name="FakerText")) - """ + _name: str + _initPerBatch: bool + _initFn: Callable | None + _rootProperty: object | None - def __init__(self, name=None): - """ - - :param name: name of generated object (when converted to string via ``str``) - - """ + def __init__(self, name: str | None = None) -> None: self._initFn = None self._rootProperty = None self._name = "PyfuncText" if name is None else name self._initPerBatch = False - def withInit(self, fn): - """ Specifies context initialization function + def withInit(self, fn: Callable) -> "PyfuncTextFactory": + """ + Sets the initialization function for creating context. - :param fn: function pointer or lambda function for initialization - signature should ``initFunction(context)`` + :param fn: Callable function for initializing context; Signature should ``initFunction(context)`` + :returns: Modified text generation factory with the specified initialization function - .. note:: - This variation initializes the context once per worker process per text generator - instance. + .. note:: + This variation initializes the context once per worker process per text generator + instance. """ self._initFn = fn return self - def withInitPerBatch(self, fn): - """ Specifies context initialization function + def withInitPerBatch(self, fn: Callable) -> "PyfuncTextFactory": + """ + Sets the initialization function for creating context for each batch. - :param fn: function pointer or lambda function for initialization - signature should ``initFunction(context)`` + :param fn: Callable function for initializing context; Signature should ``initFunction(context)`` + :returns: Modified text generation factory with the specified initialization function called for each batch - .. note:: - This variation initializes the context once per internal pandas UDF call. - The UDF call will be called once per 10,000 rows if system is configured using defaults. - Setting the pandas batch size as an argument to the DataSpec creation will change the default - batch size. + .. note:: + This variation initializes the context once per internal pandas UDF call. + The UDF call will be called once per 10,000 rows if system is configured using defaults. + Setting the pandas batch size as an argument to the DataSpec creation will change the default + batch size. """ self._initPerBatch = True return self.withInit(fn) - def withRootProperty(self, prop): - """ If called, specifies the property of the context to be passed to the text generation function. - If not called, the context object itself will be passed to the text generation function. + def withRootProperty(self, prop: object) -> "PyfuncTextFactory": + """ + Sets the context property to be passed to the text generation function. If not called, the context object will + be passed to the text generation function. + + :param prop: Context property + :returns: Modified text generation factory with the context property """ self._rootProperty = prop return self - def __call__(self, evalFn, *args, isProperty=False, **kwargs): - """ Internal function call mechanism that implements the syntax expansion + def __call__( + self, + evalFn: str | Callable, + *args, + isProperty: bool = False, + **kwargs + ) -> PyfuncText: + """ + Internal function calling mechanism that implements the syntax expansion. 
- :param evalFn: text generation function or lambda - :param args: optional args to be passed by position - :param kwargs: optional keyword args following Python keyword passing mechanism - :param isProperty: if true, interpret evalFn as string name of property, not a function or method + :param evalFn: Callable text generation function + :param args: Optional arguments to pass by position to the text generation function + :param kwargs: Optional keyword arguments following Python keyword passing mechanism + :param isProperty: Whether to interpret the evaluation function as string name of property instead of a callable + function (default `False`) """ assert evalFn is not None and (type(evalFn) is str or callable(evalFn)), "Function must be provided" - if type(evalFn) is str: - assert self._rootProperty is not None and len(self._rootProperty.strip()) > 0, \ - "string named functions can only be used on text generators with root property" - fnName = evalFn - if len(args) > 0 and len(kwargs) > 0: - # generate lambda with both kwargs and args - assert not isProperty, "isProperty cannot be true if using arguments" - evalFn = lambda root: getattr(root, fnName)(*args, **kwargs) - elif len(args) > 0: - # generate lambda with positional args - assert not isProperty, "isProperty cannot be true if using arguments" - evalFn = lambda root: getattr(root, fnName)(*args) - elif len(kwargs) > 0: - # generate lambda with keyword args - assert not isProperty, "isProperty cannot be true if using arguments" - evalFn = lambda root: getattr(root, fnName)(**kwargs) - elif isProperty: - # generate lambda with property access, not method call - evalFn = lambda root: getattr(root, fnName) - else: - # generate lambda with no args - evalFn = (lambda root: getattr(root, fnName)()) + if isinstance(evalFn, str): + if not self._rootProperty: + raise ValueError("String named functions can only be used on text generators with root property") + function_name = evalFn + + if (len(args) > 0 or len(kwargs) > 0) and isProperty: + raise ValueError("Argument 'isProperty' cannot be used when passing arguments") + + def generated_evalFn(root: object) -> object: + method = getattr(root, function_name) + + if isProperty: + return method + elif len(args) > 0 and len(kwargs) > 0: + return method(*args, **kwargs) + elif len(args) > 0: + return method(*args) + elif len(kwargs) > 0: + return method(**kwargs) + else: + return method() + + evalFn = generated_evalFn # returns the actual PyfuncText text generator object. # Note all syntax expansion is performed once only @@ -268,24 +316,31 @@ def __call__(self, evalFn, *args, isProperty=False, **kwargs): class FakerTextFactory(PyfuncTextFactory): - """ Factory object for Faker text generator flavored ``PyfuncText`` objects + """ + Factory for creating Faker text generators. - :param locale: list of locales. If empty, defaults to ``en-US`` - :param providers: list of providers - :param name: name of generated objects. Defaults to ``FakerText`` - :param lib: library import name of Faker library. 
If none passed, uses ``faker`` - :param rootClass: name of root object class If none passed, uses ``Faker`` + :param locale: Optional list of locales (default is ``["en-US"]``) + :param providers: List of providers + :param name: Optional name of generated objects (default is ``FakerText``) + :param lib: Optional import alias of Faker library (dfault is ``"faker"``) + :param rootClass: Optional name of the root object class (default is ``"Faker"``) ..note :: Both the library name and root object class can be overridden - this is primarily for internal testing purposes. """ - _FAKER_LIB = "faker" + _defaultFakerTextFactory: Optional["FakerTextFactory"] = None + _FAKER_LIB: str = "faker" - _defaultFakerTextFactory = None - - def __init__(self, *, locale=None, providers=None, name="FakerText", lib=None, - rootClass=None): + def __init__( + self, + *, + locale: str | list[str] | None = None, + providers: list | None = None, + name: str = "FakerText", + lib: str | None = None, + rootClass: str | None = None + ) -> None: super().__init__(name) @@ -304,37 +359,42 @@ def __init__(self, *, locale=None, providers=None, name="FakerText", lib=None, self._rootObjectClass = rootClass # load the library - fakerModule = self._loadLibrary(lib) + faker_module = self._loadLibrary(lib) # make the initialization function - initFn = self._mkInitFn(fakerModule, locale, providers) + init_function = self._mkInitFn(faker_module, locale, providers) - self.withInit(initFn) + self.withInit(init_function) self.withRootProperty("faker") @classmethod - def _getDefaultFactory(cls, lib=None, rootClass=None): - """Class method to get default faker text factory + def _getDefaultFactory(cls, lib: str | None = None, rootClass: str | None = None) -> "FakerTextFactory": + """ + Gets a default faker text factory. - Not intended for general use + :param lib: Optional import alias of Faker library (dfault is ``"faker"``) + :param rootClass: Optional name of the root object class (default is ``"Faker"``) """ if cls._defaultFakerTextFactory is None: cls._defaultFakerTextFactory = FakerTextFactory(lib=lib, rootClass=rootClass) return cls._defaultFakerTextFactory - def _mkInitFn(self, libModule, locale, providers): - """ Make Faker initialization function + def _mkInitFn(self, libModule: object, locale: str | list[str] | None, providers: list | None) -> Callable: + """ + Creates a Faker initialization function. - :param locale: locale string or list of locale strings - :param providers: providers to load - :return: + :param libModule: Faker module + :param locale: Locale string or list of locale strings (e.g. "en-us") + :param providers: List of Faker providers to load + :returns: Callable initialization function """ - assert libModule is not None, "must have a valid loaded Faker library module" + if libModule is None: + raise ValueError("must have a valid loaded Faker library module") fakerClass = getattr(libModule, self._rootObjectClass) # define the initialization function for Faker - def fakerInitFn(ctx): + def fakerInitFn(ctx: _FnCallContext) -> None: if locale is not None: ctx.faker = fakerClass(locale=locale) else: @@ -342,44 +402,52 @@ def fakerInitFn(ctx): if providers is not None: for provider in providers: - ctx.faker.add_provider(provider) + ctx.faker.add_provider(provider) # type: ignore[attr-defined] return fakerInitFn - def _loadLibrary(self, lib): - """ Load faker library if not already loaded + def _loadLibrary(self, lib: str) -> ModuleType: + """ + Loads the faker library. - :param lib: library name of Faker library. 
If none passed, uses ``faker`` + :param lib: Optional alias name for Faker library (default is ``"faker"``) """ - # load library try: if lib is not None: - assert type(lib) is str and len(lib.strip()), f"Library ``{lib}`` must be a valid library name" + if not isinstance(lib, str): + raise ValueError(f"Input Faker alias with type '{type(lib)}' must be of type 'str'") + + if not lib: + raise ValueError("Input Faker alias must be provided") if lib in globals(): - return globals()[lib] + module = globals()[lib] + if isinstance(module, ModuleType): + return module + else: + raise ValueError(f"Global '{lib}' is not a module") + else: fakerModule = importlib.import_module(lib) globals()[lib] = fakerModule return fakerModule - except RuntimeError as err: - # pylint: disable=raise-missing-from - raise DataGenError("Could not load or initialize Faker library", err) - + else: + raise ValueError("Library name must be provided") -def fakerText(mname, *args, _lib=None, _rootClass=None, **kwargs): - """Generate faker text generator object using default FakerTextFactory - instance + except RuntimeError as err: + raise DataGenError("Could not load or initialize Faker library") from err - :param mname: method name to invoke - :param args: positional args to be passed to underlying Faker instance - :param _lib: internal only param - library to load - :param _rootClass: internal only param - root class to create - - :returns : instance of PyfuncText for use with Faker - ``fakerText("sentence")`` is same as ``FakerTextFactory()("sentence")`` +def fakerText(mname: str, *args, _lib: str | None = None, _rootClass: str | None = None, **kwargs) -> PyfuncText: + """ + Creates a faker text generator object using the default ``FakerTextFactory`` instance. Calling this method is + equivalent to calling ``FakerTextFactory()("sentence")``. 
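As a usage sketch (assuming a live SparkSession ``spark`` and the Faker package installed; the column names and Faker methods are illustrative), the returned ``PyfuncText`` is typically passed to ``withColumn`` via the ``text`` option:

import dbldatagen as dg
from dbldatagen.text_generator_plugins import fakerText

# Each value is produced by invoking the named Faker method once per row.
df_spec = (
    dg.DataGenerator(sparkSession=spark, name="faker_demo", rows=1000)
    .withColumn("full_name", "string", text=fakerText("name"))
    .withColumn("address", "string", text=fakerText("address"))
)
df = df_spec.build()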
+ + :param mname: Method name to invoke + :param args: Positional argumentss to pass to the Faker text generation method + :param _lib: Optional import alias of Faker library (default is ``"faker"``) + :param _rootClass: Optional name of the root object class (default is ``"Faker"``) + :returns : ``PyfuncText`` for use with Faker """ - defaultFactory = FakerTextFactory._getDefaultFactory(lib=_lib, - rootClass=_rootClass) - return defaultFactory(mname, *args, **kwargs) # pylint: disable=not-callable + default_factory = FakerTextFactory._getDefaultFactory(lib=_lib, rootClass=_rootClass) + return default_factory(mname, *args, **kwargs) # pylint: disable=not-callable diff --git a/pyproject.toml b/pyproject.toml index edffc15e..2ec12de9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,13 +156,11 @@ exclude = [ "dbldatagen/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", - "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", "dbldatagen/daterange.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", - "dbldatagen/text_generator_plugins.py", ] [tool.ruff.lint] @@ -237,14 +235,12 @@ ignore = [ "dbldatagen/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", - "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", "dbldatagen/daterange.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", - "dbldatagen/text_generator_plugins.py", "dbldatagen/utils.py" ] @@ -272,7 +268,6 @@ ignore-paths = [ "dbldatagen/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", - "dbldatagen/data_analyzer.py", "dbldatagen/data_generator.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", @@ -280,7 +275,6 @@ ignore-paths = [ "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", - "dbldatagen/text_generator_plugins.py", "dbldatagen/utils.py" ] @@ -403,14 +397,12 @@ exclude = [ "dbldatagen/__init__.py", "dbldatagen/column_generation_spec.py", "dbldatagen/column_spec_options.py", - "dbldatagen/data_analyzer.py", "dbldatagen/datagen_constants.py", "dbldatagen/datarange.py", "dbldatagen/daterange.py", "dbldatagen/nrange.py", "dbldatagen/schema_parser.py", "dbldatagen/serialization.py", - "dbldatagen/text_generator_plugins.py", "dbldatagen/utils.py" ] warn_return_any = true From 4eb173d6fd6db604949d001009cd2d63730436b0 Mon Sep 17 00:00:00 2001 From: Adamdion <65203526+Adamdion@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:28:13 -0500 Subject: [PATCH 16/20] Example notebook for Retail and Consumer Packaged Goods industry (#366) * Added CPG supply chain dbldatagen notebook * Deleting original CPG_supply_chain_datagen.py notebook * Adding new file with the changes Adding new file in the correct location in the repo. 
* Cleaned up comments --------- Co-authored-by: Greg Hansen <163584195+ghanse@users.noreply.github.com> --- examples/notebooks/retail_data_generation.py | 692 +++++++++++++++++++ 1 file changed, 692 insertions(+) create mode 100644 examples/notebooks/retail_data_generation.py diff --git a/examples/notebooks/retail_data_generation.py b/examples/notebooks/retail_data_generation.py new file mode 100644 index 00000000..62b592de --- /dev/null +++ b/examples/notebooks/retail_data_generation.py @@ -0,0 +1,692 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # CPG Supply Chain Dummy Data Generator +# MAGIC +# MAGIC ## Educational Guide to dbldatagen +# MAGIC +# MAGIC This notebook demonstrates how to use [**dbldatagen**](https://databrickslabs.github.io/dbldatagen/public_docs/index.html) to simulate data from a supply chain for consumer packaged goods (CPG). +# MAGIC +# MAGIC +# MAGIC ### Datasets We'll Create: +# MAGIC 1. **Products** - SKU master data with categories and pricing +# MAGIC 2. **Distribution Centers** - Network locations with capacity +# MAGIC 3. **Retail Stores** - Customer-facing locations +# MAGIC 4. **Orders** - Manufacturing execution data +# MAGIC 5. **Inventory Snapshots** - Multi-echelon inventory with risk metrics +# MAGIC 6. **Shipments** - Transportation and logistics data + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Installation & Setup +# MAGIC +# MAGIC dbldatagen can be installed using pip install commands, as a cluster-scoped library, or as a serverless environment-scoped library. + +# COMMAND ---------- + +# MAGIC %pip install dbldatagen + +# COMMAND ---------- + +import dbldatagen as dg +from pyspark.sql.types import * +from pyspark.sql import functions as F +from datetime import datetime, timedelta + +print(f"Using dbldatagen version: {dg.__version__}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Configuration +# MAGIC +# MAGIC **Best Practice**: Define all configuration parameters at the top for easy adjustment. + +# COMMAND ---------- + +# Data generation parameters - adjust these to scale up/down +NUM_PRODUCTS = 500 +NUM_DISTRIBUTION_CENTERS = 25 +NUM_STORES = 1000 +NUM_ORDERS = 10000 +NUM_INVENTORY_RECORDS = 50000 +NUM_SHIPMENTS = 30000 + +# Catalog configuration +CATALOG_NAME = 'CATALOG_NAME' +SCHEMA_NAME = 'SCHEMA_NAME' + +# Set up the Catalog +spark.sql(f"USE CATALOG {CATALOG_NAME}") +spark.sql(f"USE SCHEMA {SCHEMA_NAME}") + +print(f"Generating data in: {CATALOG_NAME}.{SCHEMA_NAME}") +print(f"Total records to generate: {NUM_PRODUCTS + NUM_DISTRIBUTION_CENTERS + NUM_STORES + NUM_ORDERS + NUM_INVENTORY_RECORDS + NUM_SHIPMENTS:,}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 1. 
Product Master Data +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - How to use `withIdOutput()` for unique IDs +# MAGIC - Creating string expressions with `concat()` and `lpad()` +# MAGIC - Using `values` parameter for categorical data +# MAGIC - Working with different data types (string, decimal, integer, date) +# MAGIC +# MAGIC ### Key Concepts: +# MAGIC - **uniqueValues**: Ensures the column has exactly N unique values +# MAGIC - **template**: Generates random words (\\w pattern) +# MAGIC - **minValue/maxValue**: Range for numeric values +# MAGIC - **begin/end**: Date range parameters + +# COMMAND ---------- + +# Define categorical values for products +product_categories = ["Beverages", "Snacks", "Dairy", "Bakery", "Frozen Foods", "Personal Care", "Household"] +brands = ["Premium Brand A", "Value Brand B", "Store Brand C", "Organic Brand D", "Brand E"] + +# Build the data generator specification +products_spec = ( + dg.DataGenerator(spark, name="products", rows=NUM_PRODUCTS, partitions=4) + + # withIdOutput() creates an 'id' column with sequential integers starting at 1 + .withIdOutput() + + # Create SKU codes: SKU-000001, SKU-000002, etc. + # expr allows SQL expressions; cast(id as string) converts the id to string + # lpad pads to 6 digits; uniqueValues ensures exactly NUM_PRODUCTS unique SKUs + .withColumn("sku", "string", + expr="concat('SKU-', lpad(cast(id as string), 6, '0'))", + uniqueValues=NUM_PRODUCTS) + + # template uses \\w to generate random words + .withColumn("product_name", "string", template=r"\\w \\w Product") + + # values with random=True picks randomly from the list + .withColumn("category", "string", values=product_categories, random=True) + .withColumn("brand", "string", values=brands, random=True) + + # Numeric ranges for costs and pricing + .withColumn("unit_cost", "decimal(10,2)", minValue=0.5, maxValue=50.0, random=True) + .withColumn("unit_price", "decimal(10,2)", minValue=1.0, maxValue=100.0, random=True) + + # Pick from specific values (case sizes) + .withColumn("units_per_case", "integer", values=[6, 12, 24, 48], random=True) + .withColumn("weight_kg", "decimal(8,2)", minValue=0.1, maxValue=25.0, random=True) + .withColumn("shelf_life_days", "integer", minValue=30, maxValue=730, random=True) + + # Date range for when products were created + .withColumn("created_date", "date", begin="2020-01-01", + end="2024-01-01", interval="1 day", random=True ) +) + +# Build the dataframe from the specification +df_products = products_spec.build() + +# Write to table +df_products.write.mode("overwrite").saveAsTable("products") + +display(df_products.limit(10)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 2. Distribution Centers +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - Creating location codes with expressions +# MAGIC - Generating geographic coordinates (latitude/longitude) +# MAGIC - Using realistic ranges for capacity and utilization metrics +# MAGIC +# MAGIC ### Pro Tip: +# MAGIC When generating geographic data, use realistic ranges: +# MAGIC - US Latitude: 25.0 to 49.0 (southern border to Canadian border) +# MAGIC - US Longitude: -125.0 to -65.0 (west coast to east coast) + +# COMMAND ---------- + +distribution_center_spec = ( + dg.DataGenerator(spark, name="distribution_centers", rows=NUM_DISTRIBUTION_CENTERS, partitions=4) + .withIdOutput() + + # distribution_center codes: distribution_center-0001, distribution_center-0002, etc. 
+ .withColumn("distribution_center_code", "string", + expr="concat('distribution_center-', lpad(cast(id as string), 4, '0'))", + uniqueValues=NUM_DISTRIBUTION_CENTERS) + + .withColumn("distribution_center_name", "string", template=r"\\w Distribution Center") + + # Regional distribution for US + .withColumn("region", "string", + values=["Northeast", "Southeast", "Midwest", "Southwest", "West"], + random=True) + + # Warehouse capacity metrics + .withColumn("capacity_pallets", "integer", minValue=5000, maxValue=50000, random=True) + .withColumn("current_utilization_pct", "decimal(5,2)", minValue=45.0, maxValue=95.0, random=True) + + # Geographic coordinates for mapping + .withColumn("latitude", "decimal(9,6)", minValue=25.0, maxValue=49.0, random=True) + .withColumn("longitude", "decimal(9,6)", minValue=-125.0, maxValue=-65.0, random=True) + + # Operating costs + .withColumn("operating_cost_daily", "decimal(10,2)", minValue=5000, maxValue=50000, random=True) + .withColumn("opened_date", "date", begin="2015-01-01", end="2023-01-01", random=True) +) + +df_distribution_centers = distribution_center_spec.build() +df_distribution_centers.write.mode("overwrite").saveAsTable("distribution_centers") + +print(f"Created distribution_centers table with {df_distribution_centers.count():,} records") +display(df_distribution_centers.limit(10)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 3. Retail Stores +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - Creating foreign key relationships (distribution_center_id references distribution_centers) +# MAGIC - Generating realistic store attributes +# MAGIC - Using longer store codes (6 digits vs 4 for distribution_centers) + +# COMMAND ---------- + +store_formats = ["Hypermarket", "Supermarket", "Convenience", "Online", "Club Store"] +retailers = ["RetailCo", "MegaMart", "QuickStop", "FreshGrocer", "ValueMart"] + +stores_spec = ( + dg.DataGenerator(spark, name="stores", rows=NUM_STORES, partitions=8) + .withIdOutput() + + # Store codes: STORE-000001, STORE-000002, etc. + .withColumn("store_code", "string", + expr="concat('STORE-', lpad(cast(id as string), 6, '0'))", + uniqueValues=NUM_STORES) + + .withColumn("retailer", "string", values=retailers, random=True) + .withColumn("store_format", "string", values=store_formats, random=True) + .withColumn("region", "string", + values=["Northeast", "Southeast", "Midwest", "Southwest", "West"], + random=True) + + # Store size range from small convenience to large hypermarket + .withColumn("square_footage", "integer", minValue=2000, maxValue=200000, random=True) + + # FOREIGN KEY: Links to distribution_centers table + # Each store gets a distribution_center ID between 1 and NUM_DISTRIBUTION_CENTERS + .withColumn("distribution_center_id", "integer", minValue=1, maxValue=NUM_DISTRIBUTION_CENTERS, random=True) + + .withColumn("latitude", "decimal(9,6)", minValue=25.0, maxValue=49.0, random=True) + .withColumn("longitude", "decimal(9,6)", minValue=-125.0, maxValue=-65.0, random=True) + .withColumn("opened_date", "date", begin="2010-01-01", end="2024-01-01", random=True) +) + +df_stores = stores_spec.build() +df_stores.write.mode("overwrite").saveAsTable("stores") + +print(f"Created stores table with {df_stores.count():,} records") +print(f"Each store is linked to a distribution_center via distribution_center_id foreign key") +display(df_stores.limit(10)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 4. 
Orders +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - Working with **timestamp** columns +# MAGIC - Using intermediate random columns for complex calculations +# MAGIC - Post-processing with PySpark transformations +# MAGIC - Using modulo operations for distributing categorical values +# MAGIC +# MAGIC ### Advanced Pattern: +# MAGIC When you need complex logic that depends on random values: +# MAGIC 1. Generate random "seed" columns in the spec +# MAGIC 2. Build the dataframe +# MAGIC 3. Use PySpark `.withColumn()` to create derived columns +# MAGIC 4. Drop the intermediate seed columns + +# COMMAND ---------- + +order_status = ["Scheduled", "In Progress", "Completed", "Delayed", "Quality Hold"] + +order_spec = ( + dg.DataGenerator(spark, name="orders", rows=NUM_ORDERS, partitions=8) + .withIdOutput() + + .withColumn("order_number", "string", + expr="concat('PO-', lpad(cast(id as string), 8, '0'))", + uniqueValues=NUM_ORDERS) + + # FOREIGN KEYS + .withColumn("distribution_center_id", "integer", minValue=1, maxValue=NUM_DISTRIBUTION_CENTERS, random=True) + .withColumn("product_id", "integer", minValue=1, maxValue=NUM_PRODUCTS, random=True) + + # Base timestamp for the order + .withColumn("order_date", "timestamp", + begin="2024-01-01 00:00:00", + end="2025-09-29 23:59:59", + random=True) + + # Random seed columns for calculations (will be used then dropped) + .withColumn("scheduled_start_days", "integer", minValue=0, maxValue=10, random=True) + .withColumn("scheduled_duration_days", "integer", minValue=1, maxValue=6, random=True) + .withColumn("start_delay_hours", "integer", minValue=-12, maxValue=12, random=True) + .withColumn("actual_duration_hours", "integer", minValue=24, maxValue=144, random=True) + .withColumn("start_probability", "double", minValue=0, maxValue=1, random=True) + .withColumn("completion_probability", "double", minValue=0, maxValue=1, random=True) + .withColumn("quantity_ordered", "integer", minValue=500, maxValue=50000, random=True) + .withColumn("order_variance", "double", minValue=0.85, maxValue=1.0, random=True) + + # Use modulo to distribute status values evenly + # status_rand % 5 gives values 0-4, which we'll map to our 5 status values + .withColumn("status_rand", "integer", minValue=1, maxValue=10000, random=True) + + .withColumn("line_efficiency_pct", "decimal(5,2)", minValue=75.0, maxValue=98.0, random=True) + .withColumn("production_cost", "decimal(12,2)", minValue=5000, maxValue=500000, random=True) +) + +# Build the base dataframe +df_orders = order_spec.build() + +# POST-PROCESSING: Add calculated columns using PySpark +df_orders = ( + df_orders + # Calculate scheduled start by adding days to order_date + .withColumn("scheduled_start", + F.expr("date_add(order_date, scheduled_start_days)")) + + # Calculate scheduled end + .withColumn("scheduled_end", + F.expr("date_add(scheduled_start, scheduled_duration_days)")) + + # Actual start: only if probability > 0.3, add delay hours + .withColumn("actual_start", + F.when(F.col("start_probability") > 0.3, + F.expr("timestampadd(HOUR, start_delay_hours, scheduled_start)")) + .otherwise(None)) + + # Actual end: only if started AND probability > 0.2 + .withColumn("actual_end", + F.when((F.col("actual_start").isNotNull()) & + (F.col("completion_probability") > 0.2), + F.expr("timestampadd(HOUR, actual_duration_hours, actual_start)")) + .otherwise(None)) + + # Quantity produced: apply variance if completed + .withColumn("quantity_produced", + F.when(F.col("actual_end").isNotNull(), + (F.col("quantity_ordered") * 
F.col("order_variance")).cast("integer")) + .otherwise(0)) + + # Map status_rand to status using modulo and array indexing + .withColumn("status_index", F.col("status_rand") % 5) + .withColumn("status", + F.array([F.lit(s) for s in order_status]).getItem(F.col("status_index"))) + + # Clean up: drop intermediate columns + .drop("scheduled_start_days", "scheduled_duration_days", "start_delay_hours", + "actual_duration_hours", "start_probability", "completion_probability", + "order_variance", "status_rand", "status_index") +) + +df_orders.write.mode("overwrite").saveAsTable("orders") + +print(f"Created orders table with {df_orders.count():,} records") +print(f"Order Status distribution:") +df_orders.groupBy("status").count().orderBy("status").show() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 5. Inventory Snapshots +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - Using **CASE expressions** in SQL for conditional logic +# MAGIC - Creating weighted distributions with seed columns +# MAGIC - Handling division by zero with conditional logic +# MAGIC - Post-processing for complex foreign key relationships +# MAGIC +# MAGIC ### Pattern: Weighted Categorical Distribution +# MAGIC To get 30% distribution_center and 70% Store: +# MAGIC 1. Create a seed column with values 0-1 +# MAGIC 2. Use CASE: when seed < 0.3 then 'distribution_center' else 'Store' +# MAGIC +# MAGIC ### Pattern: Safe Division +# MAGIC Always check denominator before dividing to avoid errors + +# COMMAND ---------- + +inventory_spec = ( + dg.DataGenerator(spark, name="inventory", rows=NUM_INVENTORY_RECORDS, partitions=8) + .withIdOutput() + + # Date range for inventory snapshots + .withColumn("snapshot_date", "date", + begin="2024-01-01", + end="2025-09-29", + random=True) + + # Weighted distribution: 30% distribution_center, 70% Store + .withColumn("location_type_seed", "double", minValue=0, maxValue=1, random=True) + .withColumn("location_type", "string", expr=""" + CASE + WHEN location_type_seed < 0.3 THEN 'distribution_center' + ELSE 'Store' + END + """) + + # Create location_id based on location_type using expr + .withColumn("location_id", "integer", expr=""" + CASE + WHEN location_type = 'distribution_center' THEN (id % 25) + 1 + ELSE (id % 1000) + 1 + END + """) + + # FOREIGN KEY + .withColumn("product_id", "integer", minValue=1, maxValue=NUM_PRODUCTS, random=True) + + # Inventory quantities + .withColumn("quantity_on_hand", "integer", minValue=0, maxValue=10000, random=True) + .withColumn("reserve_factor", "double", minValue=0, maxValue=0.5, random=True) + + # Calculate reserved quantity using expr + .withColumn("quantity_reserved", "integer", expr="cast(quantity_on_hand * reserve_factor as int)") + + # Calculate available quantity + .withColumn("quantity_available", "integer", expr="quantity_on_hand - quantity_reserved") + + .withColumn("reorder_point", "integer", minValue=100, maxValue=2000, random=True) + + # Demand rate for calculations + .withColumn("daily_demand", "double", minValue=50.0, maxValue=150.0, random=True) + + # Calculate days of supply with safe division + .withColumn("days_of_supply", "decimal(8,2)", expr=""" + CASE + WHEN daily_demand > 0 THEN cast(quantity_available / daily_demand as decimal(8,2)) + ELSE NULL + END + """) + + .withColumn("inventory_value", "decimal(12,2)", minValue=1000, maxValue=500000, random=True) + .withColumn("days_offset", "integer", minValue=0, maxValue=60, random=True) + + # Date arithmetic using expr + .withColumn("last_received_date", "date", 
expr="date_sub(snapshot_date, days_offset)") + + # Risk categorization using expr + .withColumn("stockout_risk", "string", expr=""" + CASE + WHEN days_of_supply IS NULL OR days_of_supply < 3 THEN 'High' + WHEN days_of_supply < 7 THEN 'Medium' + ELSE 'Low' + END + """) +) + +# Build and drop intermediate columns +df_inventory = inventory_spec.build().drop("reserve_factor", "days_offset", "location_type_seed") + +df_inventory.write.mode("overwrite").saveAsTable("inventory") + +print(f"Created inventory table with {df_inventory.count():,} records") +print(f"Location type distribution:") +df_inventory.groupBy("location_type").count().show() +print(f"Stockout risk distribution:") +df_inventory.groupBy("stockout_risk").count().orderBy("stockout_risk").show() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 6. Shipments +# MAGIC +# MAGIC ### Learning Objectives: +# MAGIC - Creating **multiple weighted categorical columns** +# MAGIC - Working with date arithmetic (transit times) +# MAGIC - Computing derived metrics (on_time, delay_hours) +# MAGIC - Handling NULL values in calculations +# MAGIC +# MAGIC ### Pattern: Multiple Weighted Categories +# MAGIC For transport_mode with weights [60%, 15%, 20%, 5%]: +# MAGIC - 0.00-0.60: Truck (60%) +# MAGIC - 0.60-0.75: Rail (15%) +# MAGIC - 0.75-0.95: Intermodal (20%) +# MAGIC - 0.95-1.00: Air (5%) + +# COMMAND ---------- + +shipment_status = ["In Transit", "Delivered", "Delayed", "At Hub", "Out for Delivery"] +transport_modes = ["Truck", "Rail", "Intermodal", "Air"] + +shipments_spec = ( + dg.DataGenerator(spark, name="shipments", rows=NUM_SHIPMENTS, partitions=8) + .withIdOutput() + + .withColumn("shipment_id", "string", + expr="concat('SHP-', lpad(cast(id as string), 10, '0'))", + uniqueValues=NUM_SHIPMENTS) + + # FOREIGN KEY: Origin is always a distribution_center + .withColumn("origin_distribution_center_id", "integer", minValue=1, maxValue=NUM_DISTRIBUTION_CENTERS, random=True) + + # Destination can be distribution_center or Store (30% distribution_center, 70% Store) + .withColumn("destination_type_seed", "double", minValue=0, maxValue=1, random=True) + .withColumn("destination_type", "string", expr=""" + CASE + WHEN destination_type_seed < 0.3 THEN 'distribution_center' + ELSE 'Store' + END + """) + + # Create destination_id based on destination_type + .withColumn("destination_id", "integer", expr=""" + CASE + WHEN destination_type = 'distribution_center' THEN (id % 25) + 1 + ELSE (id % 1000) + 1 + END + """) + + .withColumn("product_id", "integer", minValue=1, maxValue=NUM_PRODUCTS, random=True) + + # Shipment dates + .withColumn("ship_date", "timestamp", + begin="2024-01-01 00:00:00", + end="2025-09-29 23:59:59", + random=True) + + # Transit time ranges + .withColumn("transit_days", "integer", minValue=1, maxValue=6, random=True) + .withColumn("actual_transit_days", "integer", minValue=1, maxValue=8, random=True) + .withColumn("delivery_probability", "double", minValue=0, maxValue=1, random=True) + + # Expected delivery = ship_date + transit_days (using date_add) + .withColumn("expected_delivery", "timestamp", expr="date_add(ship_date, transit_days)") + + # Actual delivery: only 80% of shipments are delivered + .withColumn("actual_delivery", "timestamp", expr=""" + CASE + WHEN delivery_probability > 0.2 THEN date_add(ship_date, actual_transit_days) + ELSE NULL + END + """) + + # On-time check: delivered AND before/at expected time + .withColumn("on_time", "boolean", expr=""" + actual_delivery IS NOT NULL AND actual_delivery <= expected_delivery + 
""") + + # Calculate delay in hours (can be negative for early deliveries) + .withColumn("delay_hours", "integer", expr=""" + CASE + WHEN actual_delivery IS NOT NULL THEN + cast((unix_timestamp(actual_delivery) - unix_timestamp(expected_delivery)) / 3600 as int) + ELSE NULL + END + """) + + .withColumn("quantity", "integer", minValue=100, maxValue=5000, random=True) + + # Transport mode with weighted distribution: 60% Truck, 15% Rail, 20% Intermodal, 5% Air + .withColumn("transport_mode_seed", "double", minValue=0, maxValue=1, random=True) + .withColumn("transport_mode", "string", expr=""" + CASE + WHEN transport_mode_seed < 0.60 THEN 'Truck' + WHEN transport_mode_seed < 0.75 THEN 'Rail' + WHEN transport_mode_seed < 0.95 THEN 'Intermodal' + ELSE 'Air' + END + """) + + .withColumn("carrier", "string", + values=["FastFreight", "ReliableLogistics", "ExpressTransport", "GlobalShippers"], + random=True) + + # Status with weighted distribution: 25% In Transit, 50% Delivered, 5% Delayed, 10% At Hub, 10% Out for Delivery + .withColumn("status_seed", "double", minValue=0, maxValue=1, random=True) + .withColumn("status", "string", expr=""" + CASE + WHEN status_seed < 0.25 THEN 'In Transit' + WHEN status_seed < 0.75 THEN 'Delivered' + WHEN status_seed < 0.80 THEN 'Delayed' + WHEN status_seed < 0.90 THEN 'At Hub' + ELSE 'Out for Delivery' + END + """) + + .withColumn("shipping_cost", "decimal(10,2)", minValue=50, maxValue=5000, random=True) + .withColumn("distance_miles", "integer", minValue=50, maxValue=2500, random=True) +) + +# Build and drop intermediate columns +df_shipments = shipments_spec.build().drop( + "transit_days", "actual_transit_days", "delivery_probability", + "destination_type_seed", "transport_mode_seed", "status_seed" +) + +df_shipments.write.mode("overwrite").saveAsTable("shipments") + +print(f"Created shipments table with {df_shipments.count():,} records") +print(f"Transport mode distribution:") +df_shipments.groupBy("transport_mode").count().orderBy(F.desc("count")).show() +print(f"Shipment status distribution:") +df_shipments.groupBy("status").count().orderBy(F.desc("count")).show() +display(df_shipments) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Demo Use Cases +# MAGIC +# MAGIC This dataset enables the following analytics use cases: +# MAGIC +# MAGIC ### Inventory Optimization +# MAGIC - Stockout risk identification and prediction +# MAGIC - Days of supply analysis by product/location +# MAGIC - Slow-moving inventory identification +# MAGIC +# MAGIC ### Logistics & Transportation +# MAGIC - Carrier performance scorecards (OTD%, cost, speed) +# MAGIC - Route optimization opportunities +# MAGIC - Transport mode analysis (cost vs speed tradeoffs) +# MAGIC +# MAGIC ### Order Planning +# MAGIC - Order schedule optimization +# MAGIC - Line efficiency tracking +# MAGIC - Capacity planning and utilization +# MAGIC +# MAGIC ### Supply Chain Analytics +# MAGIC - End-to-end supply chain visibility +# MAGIC - Network optimization (distribution_center placement, capacity) +# MAGIC - Working capital optimization +# MAGIC +# MAGIC ### AI/ML Use Cases +# MAGIC - Demand forecasting +# MAGIC - Predictive maintenance (production efficiency) +# MAGIC - Shipment delay prediction + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Sample Queries to Get Started +# MAGIC +# MAGIC Here are some queries you can run to explore the data. 
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Query 1: Current Inventory Health + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC -- Inventory health by location type and risk level +# MAGIC SELECT +# MAGIC location_type, +# MAGIC stockout_risk, +# MAGIC COUNT(*) as item_count, +# MAGIC SUM(inventory_value) as total_value, +# MAGIC ROUND(AVG(days_of_supply), 1) as avg_days_supply +# MAGIC FROM inventory +# MAGIC WHERE snapshot_date = (SELECT MAX(snapshot_date) FROM inventory) +# MAGIC GROUP BY location_type, stockout_risk +# MAGIC ORDER BY location_type, +# MAGIC CASE stockout_risk +# MAGIC WHEN 'High' THEN 1 +# MAGIC WHEN 'Medium' THEN 2 +# MAGIC WHEN 'Low' THEN 3 +# MAGIC END + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Query 2: Carrier Performance Comparison + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC -- Compare carriers on key metrics +# MAGIC SELECT +# MAGIC carrier, +# MAGIC COUNT(*) as total_shipments, +# MAGIC ROUND(AVG(CASE WHEN on_time = true THEN 100.0 ELSE 0.0 END), 1) as otd_pct, +# MAGIC ROUND(AVG(shipping_cost), 2) as avg_cost, +# MAGIC ROUND(AVG(distance_miles), 0) as avg_distance, +# MAGIC ROUND(AVG(shipping_cost / distance_miles), 3) as cost_per_mile +# MAGIC FROM shipments +# MAGIC WHERE actual_delivery IS NOT NULL +# MAGIC GROUP BY carrier +# MAGIC ORDER BY total_shipments DESC + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Query 3: Supply Chain Network Overview + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC -- distribution_center performance and utilization +# MAGIC SELECT +# MAGIC distribution_center.distribution_center_code, +# MAGIC distribution_center.region, +# MAGIC distribution_center.capacity_pallets, +# MAGIC ROUND(distribution_center.current_utilization_pct, 1) as utilization_pct, +# MAGIC COUNT(DISTINCT i.product_id) as active_skus, +# MAGIC SUM(i.inventory_value) as inventory_value, +# MAGIC COUNT(DISTINCT s.id) as outbound_shipments_last_30d, +# MAGIC ROUND(AVG(CASE WHEN s.on_time = true THEN 100.0 ELSE 0.0 END), 1) as otd_pct +# MAGIC FROM distribution_centers distribution_center +# MAGIC LEFT JOIN inventory i ON distribution_center.id = i.location_id +# MAGIC AND i.location_type = 'distribution_center' +# MAGIC AND i.snapshot_date = (SELECT MAX(snapshot_date) FROM inventory) +# MAGIC LEFT JOIN shipments s ON distribution_center.id = s.origin_distribution_center_id +# MAGIC AND s.ship_date >= CURRENT_DATE - INTERVAL 30 DAY +# MAGIC GROUP BY distribution_center.distribution_center_code, distribution_center.region, distribution_center.capacity_pallets, distribution_center.current_utilization_pct +# MAGIC ORDER BY inventory_value DESC + +# COMMAND ---------- + From da8bc16711452dbcefe46627984ec112d39258f3 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Wed, 10 Dec 2025 10:24:25 -0500 Subject: [PATCH 17/20] renaming the keys and falling back to absolute imports --- dbldatagen/spec/__init__.py | 2 +- dbldatagen/spec/column_spec.py | 62 ++++++------------- dbldatagen/spec/compat.py | 5 +- dbldatagen/spec/generator_spec.py | 14 ++--- dbldatagen/spec/generator_spec_impl.py | 12 ++-- dbldatagen/types.py | 26 ++++++++ .../basic_stock_ticker_datagen_spec.py | 4 +- .../basic_user_datagen_spec.py | 4 +- tests/test_datagen_specs.py | 32 +++++----- tests/test_specs.py | 43 +++++++------ tests/test_standard_datasets.py | 2 +- 11 files changed, 105 insertions(+), 101 deletions(-) create mode 100644 dbldatagen/types.py diff --git a/dbldatagen/spec/__init__.py b/dbldatagen/spec/__init__.py index afede3f6..a22fb217 100644 --- 
a/dbldatagen/spec/__init__.py +++ b/dbldatagen/spec/__init__.py @@ -11,7 +11,7 @@ from typing import Any # Import only the compat layer by default to avoid triggering Spark/heavy dependencies -from .compat import BaseModel, Field, constr, root_validator, validator +from dbldatagen.spec.compat import BaseModel, Field, constr, root_validator, validator # Lazy imports for heavy modules - import these explicitly when needed diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py index c6fa20f8..3d81c462 100644 --- a/dbldatagen/spec/column_spec.py +++ b/dbldatagen/spec/column_spec.py @@ -1,58 +1,33 @@ from __future__ import annotations -from typing import Any, Literal - -from .compat import BaseModel, root_validator - - -DbldatagenBasicType = Literal[ - "string", - "int", - "long", - "float", - "double", - "decimal", - "boolean", - "date", - "timestamp", - "short", - "byte", - "binary", - "integer", - "bigint", - "tinyint", -] -"""Type alias representing supported basic Spark SQL data types for column definitions. - -Includes both standard SQL types (e.g. string, int, double) and Spark-specific type names -(e.g. bigint, tinyint). These types are used in the ColumnDefinition to specify the data type -for generated columns. -""" +from typing import Any + +from dbldatagen.spec.compat import BaseModel, root_validator +from dbldatagen.types import DbldatagenBasicType class ColumnDefinition(BaseModel): """Defines the specification for a single column in a synthetic data table. - This class encapsulates all the information needed to generate data for a single column, - including its name, type, constraints, and generation options. It supports both primary key - columns and derived columns that can reference other columns. + It supports primary key columns, data columns, and derived columns that reference other columns. - :param name: Name of the column to be generated + :param name: Name of the column to be generated (required) :param type: Spark SQL data type for the column (e.g., "string", "int", "timestamp"). - If None, type may be inferred from options or baseColumn + If None, type may be inferred from options or baseColumn. Defaults to None :param primary: If True, this column will be treated as a primary key column with unique values. - Primary columns cannot have min/max options and cannot be nullable + Primary columns cannot have min/max options and cannot be nullable. Defaults to False :param options: Dictionary of additional options controlling column generation behavior. Common options include: min, max, step, values, template, distribution, etc. - See dbldatagen documentation for full list of available options - :param nullable: If True, the column may contain NULL values. Primary columns cannot be nullable + See dbldatagen documentation for full list of available options. Defaults to None + :param nullable: If True, the column may contain NULL values. Primary columns cannot be nullable. + Defaults to False :param omit: If True, this column will be generated internally but excluded from the final output. - Useful for intermediate columns used in calculations + Useful for intermediate columns used in calculations. Defaults to False :param baseColumn: Name of another column to use as the basis for generating this column's values. - Default is "id" which refers to the internal row identifier + Defaults to "id" which refers to the internal row identifier :param baseColumnType: Method for deriving values from the baseColumn. 
Common values: "auto" (infer behavior), "hash" (hash the base column values), - "values" (use base column values directly) + "values" (use base column values directly). Defaults to "auto" .. warning:: Experimental - This API is subject to change in future versions @@ -84,7 +59,8 @@ def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]: constraints that depend on multiple fields being set. It ensures that primary key columns meet all necessary requirements and that conflicting options are not specified. - :param values: Dictionary of all field values for this ColumnDefinition instance + :param cls: The ColumnDefinition class (automatically provided by Pydantic) + :param values: Dictionary of all field values for this ColumnDefinition parameters :returns: The validated values dictionary, unmodified if all validations pass :raises ValueError: If primary column has min/max options, or if primary column is nullable, or if primary column doesn't have a type defined @@ -100,11 +76,11 @@ def check_model_constraints(cls, values: dict[str, Any]) -> dict[str, Any]: if is_primary: if "min" in options or "max" in options: - raise ValueError(f"Primary column '{name}' cannot have min/max options.") + raise ValueError(f"Primary key column '{name}' cannot have min/max options.") if is_nullable: - raise ValueError(f"Primary column '{name}' cannot be nullable.") + raise ValueError(f"Primary key column '{name}' cannot be nullable.") if column_type is None: - raise ValueError(f"Primary column '{name}' must have a type defined.") + raise ValueError(f"Primary key column '{name}' must have a type defined.") return values diff --git a/dbldatagen/spec/compat.py b/dbldatagen/spec/compat.py index 72215c0c..72f9e63b 100644 --- a/dbldatagen/spec/compat.py +++ b/dbldatagen/spec/compat.py @@ -16,7 +16,7 @@ Always import from this compat module, not directly from pydantic:: # Correct - from .compat import BaseModel, validator + from dbldatagen.spec.compat import BaseModel, validator # Incorrect - don't do this from pydantic import BaseModel, validator @@ -26,6 +26,9 @@ - **Pydantic V1.x environments**: Imports directly from pydantic package - **Databricks runtimes**: Works with pre-installed Pydantic versions without conflicts + - **DBR 16.4 onwards**: Pydantic 2.8+ available + - **DBR 15.4 and below**: Pydantic 1.10.6 available + .. note:: This approach is inspired by FastAPI's compatibility layer: https://github.com/fastapi/fastapi/blob/master/fastapi/_compat.py diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 386178cd..42976ee9 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -class TableDefinition(BaseModel): +class DatasetDefinition(BaseModel): """Defines the complete specification for a single synthetic data table. This class encapsulates all the information needed to generate a table of synthetic data, @@ -46,12 +46,12 @@ class TableDefinition(BaseModel): class DatagenSpec(BaseModel): - """Top-level specification for synthetic data generation across one or more tables. + """Top-level specification for synthetic data generation across one or more datasets. This is the main configuration class for the dbldatagen spec-based API. It defines all tables to be generated, where the output should be written, and global generation options. - :param tables: Dictionary mapping table names to their TableDefinition specifications. 
+ :param datasets: Dictionary mapping table names to their DatasetDefinition specifications. Keys are the table names that will be used in the output destination :param output_destination: Target location for generated data. Can be either a UCSchemaTarget (Unity Catalog) or FilePathTarget (file system). @@ -77,7 +77,7 @@ class DatagenSpec(BaseModel): Multiple tables can share the same DatagenSpec and will be generated in the order they appear in the tables dictionary """ - tables: dict[str, TableDefinition] + datasets: dict[str, DatasetDefinition] output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? talk to Greg generator_options: dict[str, Any] | None = None intended_for_databricks: bool | None = None # May be infered. @@ -171,11 +171,11 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove result = ValidationResult() # 1. Check that there's at least one table - if not self.tables: + if not self.datasets: result.add_error("Spec must contain at least one table definition") # 2. Validate each table (continue checking all tables even if errors found) - for table_name, table_def in self.tables.items(): + for table_name, table_def in self.datasets.items(): # Check table has at least one column if not table_def.columns: result.add_error(f"Table '{table_name}' must have at least one column") @@ -283,7 +283,7 @@ def display_all_tables(self) -> None: .. note:: This is intended for interactive exploration and debugging of spec configurations """ - for table_name, table_def in self.tables.items(): + for table_name, table_def in self.datasets.items(): print(f"Table: {table_name}") if self.output_destination: diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index fc56699b..ac42204f 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -5,7 +5,7 @@ from pyspark.sql import SparkSession import dbldatagen as dg -from dbldatagen.spec.generator_spec import TableDefinition +from dbldatagen.spec.generator_spec import DatasetDefinition from .generator_spec import ColumnDefinition, DatagenSpec, FilePathTarget, UCSchemaTarget @@ -161,7 +161,7 @@ def _prepareDataGenerators( DataGenerators before data generation begins """ logger.info( - f"Preparing data generators for {len(config.tables)} tables") + f"Preparing data generators for {len(config.datasets)} tables") if not self.spark: logger.error( @@ -169,11 +169,11 @@ def _prepareDataGenerators( raise RuntimeError( "SparkSession is not available. 
Cannot prepare data generators") - tables_config: dict[str, TableDefinition] = config.tables + tables_config: dict[str, DatasetDefinition] = config.datasets global_gen_options = config.generator_options if config.generator_options else {} prepared_generators: dict[str, dg.DataGenerator] = {} - generation_order = list(tables_config.keys()) # This becomes impotant when we get into multitable + generation_order = list(tables_config.keys()) # This becomes important when we get into multitable for table_name in generation_order: table_spec = tables_config[table_name] @@ -322,13 +322,13 @@ def generateAndWriteData( >>> generator = Generator(spark) >>> generator.generateAndWriteData(spec) """ - logger.info(f"Starting combined data generation and writing for {len(config.tables)} tables") + logger.info(f"Starting combined data generation and writing for {len(config.datasets)} tables") try: # Phase 1: Prepare data generators prepared_generators_map = self._prepareDataGenerators(config, config_source_name) - if not prepared_generators_map and list(config.tables.keys()): + if not prepared_generators_map and list(config.datasets.keys()): logger.warning( "No data generators were successfully prepared, though tables were defined") return diff --git a/dbldatagen/types.py b/dbldatagen/types.py new file mode 100644 index 00000000..e72a30f0 --- /dev/null +++ b/dbldatagen/types.py @@ -0,0 +1,26 @@ +from typing import Literal + + +DbldatagenBasicType = Literal[ + "string", + "int", + "long", + "float", + "double", + "decimal", + "boolean", + "date", + "timestamp", + "short", + "byte", + "binary", + "integer", + "bigint", + "tinyint", +] +"""Type alias representing supported basic Spark SQL data types for column definitions. + +Includes both standard SQL types (e.g. string, int, double) and Spark-specific type names +(e.g. bigint, tinyint). These types are used in the ColumnDefinition to specify the data type +for generated columns. +""" diff --git a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py index 5cc18edd..1095d270 100644 --- a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py +++ b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py @@ -6,7 +6,7 @@ from random import random -from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition from dbldatagen.spec.column_spec import ColumnDefinition @@ -266,7 +266,7 @@ def create_basic_stock_ticker_spec( ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=number_of_rows, partitions=partitions, columns=columns diff --git a/examples/datagen_from_specs/basic_user_datagen_spec.py b/examples/datagen_from_specs/basic_user_datagen_spec.py index ef0077e2..450f9b9e 100644 --- a/examples/datagen_from_specs/basic_user_datagen_spec.py +++ b/examples/datagen_from_specs/basic_user_datagen_spec.py @@ -4,7 +4,7 @@ the basic user dataset, corresponding to the BasicUserProvider. 
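A minimal sketch, not part of the patch itself, of how the example spec modules above are intended to be used, assuming the DatasetDefinition class and the `datasets` field name that later patches in this series settle on:

from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition

columns = [
    # Primary key: unique values derived from the internal "id" column.
    ColumnDefinition(name="customer_id", type="long", primary=True),
    # Data columns typed with DbldatagenBasicType values and driven by options.
    ColumnDefinition(name="name", type="string", options={"template": r"\w \w"}),
    ColumnDefinition(name="email", type="string", options={"template": r"\w@\w.com"}),
]

spec = DatagenSpec(
    datasets={"users": DatasetDefinition(number_of_rows=1000, partitions=2, columns=columns)},
    output_destination=None,  # a UCSchemaTarget or FilePathTarget can be attached later
    generator_options={"randomSeedMethod": "hash_fieldname"},
)

# Non-strict validation tolerates warnings (e.g. the missing output destination)
# but still raises on errors.
result = spec.validate(strict=False)
print(result.errors, result.warnings)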
""" -from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition from dbldatagen.spec.column_spec import ColumnDefinition @@ -77,7 +77,7 @@ def create_basic_user_spec( ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=number_of_rows, partitions=partitions, columns=columns diff --git a/tests/test_datagen_specs.py b/tests/test_datagen_specs.py index 105c3eeb..589d5e6b 100644 --- a/tests/test_datagen_specs.py +++ b/tests/test_datagen_specs.py @@ -3,7 +3,7 @@ import unittest # Import DatagenSpec classes directly to avoid Spark initialization -from dbldatagen.spec.generator_spec import DatagenSpec, TableDefinition +from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition from dbldatagen.spec.column_spec import ColumnDefinition @@ -30,7 +30,7 @@ def test_basic_user_spec_creation(self): ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=1000, partitions=2, columns=columns @@ -42,10 +42,10 @@ def test_basic_user_spec_creation(self): ) self.assertIsNotNone(spec) - self.assertIn("users", spec.tables) - self.assertEqual(spec.tables["users"].number_of_rows, 1000) - self.assertEqual(spec.tables["users"].partitions, 2) - self.assertEqual(len(spec.tables["users"].columns), 3) + self.assertIn("users", spec.datasets) + self.assertEqual(spec.datasets["users"].number_of_rows, 1000) + self.assertEqual(spec.datasets["users"].partitions, 2) + self.assertEqual(len(spec.datasets["users"].columns), 3) def test_basic_user_spec_validation(self): """Test validating a basic user DatagenSpec.""" @@ -62,7 +62,7 @@ def test_basic_user_spec_validation(self): ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=100, columns=columns ) @@ -92,7 +92,7 @@ def test_column_with_base_column(self): ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=50, columns=columns ) @@ -138,7 +138,7 @@ def test_basic_stock_ticker_spec_creation(self): ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=1000, partitions=2, columns=columns @@ -150,9 +150,9 @@ def test_basic_stock_ticker_spec_creation(self): ) self.assertIsNotNone(spec) - self.assertIn("stock_tickers", spec.tables) - self.assertEqual(spec.tables["stock_tickers"].number_of_rows, 1000) - self.assertEqual(len(spec.tables["stock_tickers"].columns), 5) + self.assertIn("stock_tickers", spec.datasets) + self.assertEqual(spec.datasets["stock_tickers"].number_of_rows, 1000) + self.assertEqual(len(spec.datasets["stock_tickers"].columns), 5) def test_stock_ticker_with_omitted_columns(self): """Test creating spec with omitted intermediate columns.""" @@ -175,7 +175,7 @@ def test_stock_ticker_with_omitted_columns(self): ), ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=100, columns=columns ) @@ -213,7 +213,7 @@ def test_duplicate_column_names(self): ColumnDefinition(name="id", type="string"), # Duplicate! 
] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=100, columns=columns ) @@ -235,7 +235,7 @@ def test_negative_rows_validation(self): ] # Create with negative rows using dict to bypass Pydantic validation - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=-100, # Invalid columns=columns ) @@ -255,7 +255,7 @@ def test_spec_with_generator_options(self): ColumnDefinition(name="value", type="long") ] - table_def = TableDefinition( + table_def = DatasetDefinition( number_of_rows=100, columns=columns ) diff --git a/tests/test_specs.py b/tests/test_specs.py index 87f40c58..4e994c5f 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -1,8 +1,7 @@ -from dbldatagen.spec.generator_spec import DatagenSpec import pytest from dbldatagen.spec.generator_spec import ( DatagenSpec, - TableDefinition, + DatasetDefinition, ColumnDefinition, UCSchemaTarget, FilePathTarget, @@ -103,7 +102,7 @@ class TestDatagenSpecValidation: def test_valid_spec_passes_validation(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -127,7 +126,7 @@ def test_empty_tables_raises_error(self): def test_table_without_columns_raises_error(self): spec = DatagenSpec( tables={ - "empty_table": TableDefinition( + "empty_table": DatasetDefinition( number_of_rows=100, columns=[] ) @@ -140,7 +139,7 @@ def test_table_without_columns_raises_error(self): def test_negative_row_count_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=-10, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) @@ -153,7 +152,7 @@ def test_negative_row_count_raises_error(self): def test_zero_row_count_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=0, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) @@ -166,7 +165,7 @@ def test_zero_row_count_raises_error(self): def test_invalid_partitions_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, partitions=-5, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -180,7 +179,7 @@ def test_invalid_partitions_raises_error(self): def test_duplicate_column_names_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -197,7 +196,7 @@ def test_duplicate_column_names_raises_error(self): def test_invalid_base_column_reference_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -213,7 +212,7 @@ def test_invalid_base_column_reference_raises_error(self): def test_circular_dependency_raises_error(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -231,7 +230,7 @@ def test_circular_dependency_raises_error(self): def test_multiple_primary_columns_warning(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id1", type="int", primary=True), @@ -254,7 +253,7 @@ def 
test_multiple_primary_columns_warning(self): def test_column_without_type_or_options_warning(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -272,7 +271,7 @@ def test_column_without_type_or_options_warning(self): def test_no_output_destination_warning(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) @@ -287,7 +286,7 @@ def test_no_output_destination_warning(self): def test_unknown_generator_option_warning(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) @@ -304,7 +303,7 @@ def test_multiple_errors_collected(self): """Test that all errors are collected before raising""" spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=-10, # Error 1 partitions=0, # Error 2 columns=[ @@ -329,7 +328,7 @@ def test_multiple_errors_collected(self): def test_strict_mode_raises_on_warnings(self): spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) @@ -349,7 +348,7 @@ def test_valid_base_column_chain(self): """Test that valid baseColumn chains work""" spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ ColumnDefinition(name="id", type="int", primary=True), @@ -368,15 +367,15 @@ def test_multiple_tables_validation(self): """Test validation across multiple tables""" spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ), - "orders": TableDefinition( + "orders": DatasetDefinition( number_of_rows=-50, # Error in second table columns=[ColumnDefinition(name="order_id", type="int", primary=True)] ), - "products": TableDefinition( + "products": DatasetDefinition( number_of_rows=200, columns=[] # Error: no columns ) @@ -425,7 +424,7 @@ def test_realistic_valid_spec(self): """Test a realistic, valid specification""" spec = DatagenSpec( tables={ - "users": TableDefinition( + "users": DatasetDefinition( number_of_rows=1000, partitions=4, columns=[ @@ -441,7 +440,7 @@ def test_realistic_valid_spec(self): }), ] ), - "orders": TableDefinition( + "orders": DatasetDefinition( number_of_rows=5000, columns=[ ColumnDefinition(name="order_id", type="int", primary=True), diff --git a/tests/test_standard_datasets.py b/tests/test_standard_datasets.py index 90afac04..e8272b4f 100644 --- a/tests/test_standard_datasets.py +++ b/tests/test_standard_datasets.py @@ -190,7 +190,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, parti ds_definition = X1.getDatasetDefinition() print("ds_definition", ds_definition) assert ds_definition.name == "providers/X1" - assert ds_definition.tables == [DatasetProvider.DEFAULT_TABLE_NAME] + assert ds_definition.datasets == [DatasetProvider.DEFAULT_TABLE_NAME] assert ds_definition.primaryTable == DatasetProvider.DEFAULT_TABLE_NAME assert ds_definition.summary is not None assert ds_definition.description is not None From 99d55266beba57aa3aaa86650540bb166bf23fe2 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Wed, 10 Dec 2025 14:35:48 -0500 
Subject: [PATCH 18/20] making methods static, and fixing tests --- dbldatagen/spec/generator_spec.py | 316 +++++++++++------- dbldatagen/spec/generator_spec_impl.py | 143 ++++---- .../basic_stock_ticker_datagen_spec.py | 8 +- tests/test_datagen_specs.py | 18 +- tests/test_specs.py | 36 +- tests/test_standard_datasets.py | 2 +- 6 files changed, 301 insertions(+), 222 deletions(-) diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 42976ee9..cfcba822 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -16,8 +16,177 @@ logger = logging.getLogger(__name__) +def _validate_table_basic_properties( + table_name: str, + table_def: DatasetDefinition, + result: ValidationResult +) -> bool: + """Validate basic table properties like columns, row count, and partitions. + + :param table_name: Name of the table being validated + :param table_def: DatasetDefinition object to validate + :param result: ValidationResult to collect errors/warnings + :returns: True if table has columns and can proceed with further validation, + False if table has no columns and should skip further checks + """ + # Check table has at least one column + if not table_def.columns: + result.add_error(f"Table '{table_name}' must have at least one column") + return False + + # Check row count is positive + if table_def.number_of_rows <= 0: + result.add_error( + f"Table '{table_name}' has invalid number_of_rows: {table_def.number_of_rows}. " + "Must be a positive integer." + ) + + # Check partitions if specified + if table_def.partitions is not None and table_def.partitions <= 0: + result.add_error( + f"Table '{table_name}' has invalid partitions: {table_def.partitions}. " + "Must be a positive integer or None." + ) + + return True + + +def _validate_duplicate_columns( + table_name: str, + table_def: DatasetDefinition, + result: ValidationResult +) -> None: + """Check for duplicate column names within a table. + + :param table_name: Name of the table being validated + :param table_def: DatasetDefinition object to validate + :param result: ValidationResult to collect errors/warnings + """ + column_names = [col.name for col in table_def.columns] + duplicates = [name for name in set(column_names) if column_names.count(name) > 1] + if duplicates: + result.add_error( + f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}" + ) + + +def _validate_column_references( + table_name: str, + table_def: DatasetDefinition, + result: ValidationResult +) -> None: + """Validate that baseColumn references point to existing columns. + + :param table_name: Name of the table being validated + :param table_def: DatasetDefinition object to validate + :param result: ValidationResult to collect errors/warnings + """ + column_map = {col.name: col for col in table_def.columns} + for col in table_def.columns: + if col.baseColumn and col.baseColumn != "id": + if col.baseColumn not in column_map: + result.add_error( + f"Table '{table_name}', column '{col.name}': " + f"baseColumn '{col.baseColumn}' does not exist in the table" + ) + + +def _validate_primary_key_columns( + table_name: str, + table_def: DatasetDefinition, + result: ValidationResult +) -> None: + """Validate primary key column constraints. 
+ + :param table_name: Name of the table being validated + :param table_def: DatasetDefinition object to validate + :param result: ValidationResult to collect errors/warnings + """ + primary_columns = [col for col in table_def.columns if col.primary] + if len(primary_columns) > 1: + primary_names = [col.name for col in primary_columns] + result.add_warning( + f"Table '{table_name}' has multiple primary columns: {', '.join(primary_names)}. " + "This may not be the intended behavior." + ) + + +def _validate_column_types( + table_name: str, + table_def: DatasetDefinition, + result: ValidationResult +) -> None: + """Validate column type specifications. + + :param table_name: Name of the table being validated + :param table_def: DatasetDefinition object to validate + :param result: ValidationResult to collect errors/warnings + """ + for col in table_def.columns: + if not col.primary and not col.type and not col.options: + result.add_warning( + f"Table '{table_name}', column '{col.name}': " + "No type specified and no options provided. " + "Column may not generate data as expected." + ) + + +def _check_circular_dependencies( + table_name: str, + columns: list[ColumnDefinition] +) -> list[str]: + """Check for circular dependencies in baseColumn references within a table. + + Analyzes column dependencies to detect cycles where columns reference each other + in a circular manner (e.g., col A depends on col B, col B depends on col A). + Such circular dependencies would make data generation impossible. + + :param table_name: Name of the table being validated (used in error messages) + :param columns: List of ColumnDefinition objects to check for circular dependencies + :returns: List of error message strings describing any circular dependencies found. + Empty list if no circular dependencies exist + + .. note:: + This function performs a graph traversal to detect cycles in the dependency chain + """ + errors = [] + column_map = {col.name: col for col in columns} + + for col in columns: + if col.baseColumn and col.baseColumn != "id": + # Track the dependency chain + visited: set[str] = set() + current = col.name + + while current: + if current in visited: + # Found a cycle + cycle_path = " -> ".join([*list(visited), current]) + errors.append( + f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}" + ) + break + + visited.add(current) + current_col = column_map.get(current) + + if not current_col: + break + + # Move to the next column in the chain + if current_col.baseColumn and current_col.baseColumn != "id": + if current_col.baseColumn not in column_map: + # This will be caught by _validate_column_references + break + current = current_col.baseColumn + else: + break + + return errors + + class DatasetDefinition(BaseModel): - """Defines the complete specification for a single synthetic data table. + """Defines the complete specification for a single synthetic dataset. This class encapsulates all the information needed to generate a table of synthetic data, including the number of rows, partitioning, and column specifications. @@ -82,60 +251,22 @@ class DatagenSpec(BaseModel): generator_options: dict[str, Any] | None = None intended_for_databricks: bool | None = None # May be infered. - def _check_circular_dependencies( - self, - table_name: str, - columns: list[ColumnDefinition] - ) -> list[str]: - """Check for circular dependencies in baseColumn references within a table. 
- - Analyzes column dependencies to detect cycles where columns reference each other - in a circular manner (e.g., col A depends on col B, col B depends on col A). - Such circular dependencies would make data generation impossible. - - :param table_name: Name of the table being validated (used in error messages) - :param columns: List of ColumnDefinition objects to check for circular dependencies - :returns: List of error message strings describing any circular dependencies found. - Empty list if no circular dependencies exist + def _validate_generator_options(self, result: ValidationResult) -> None: + """Validate generator options against known valid options. - .. note:: - This method performs a graph traversal to detect cycles in the dependency chain + :param result: ValidationResult to collect errors/warnings """ - errors = [] - column_map = {col.name: col for col in columns} - - for col in columns: - if col.baseColumn and col.baseColumn != "id": - # Track the dependency chain - visited: set[str] = set() - current = col.name - - while current: - if current in visited: - # Found a cycle - cycle_path = " -> ".join([*list(visited), current]) - errors.append( - f"Table '{table_name}': Circular dependency detected in column '{col.name}': {cycle_path}" - ) - break - - visited.add(current) - current_col = column_map.get(current) - - if not current_col: - break - - # Move to the next column in the chain - if current_col.baseColumn and current_col.baseColumn != "id": - if current_col.baseColumn not in column_map: - # baseColumn doesn't exist - we'll catch this in another validation - break - current = current_col.baseColumn - else: - # Reached a column that doesn't have a baseColumn or uses "id" - break - - return errors + if self.generator_options: + known_options = [ + "random", "randomSeed", "randomSeedMethod", "verbose", + "debug", "seedColumnName" + ] + for key in self.generator_options: + if key not in known_options: + result.add_warning( + f"Unknown generator option: '{key}'. " + "This may be ignored during generation." + ) def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[override] """Validate the entire DatagenSpec configuration comprehensively. @@ -156,6 +287,7 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove :param strict: Controls validation failure behavior: - If True: Raises ValueError for any errors OR warnings found - If False: Only raises ValueError for errors (warnings are tolerated) + Defaults to True :returns: ValidationResult object containing all collected errors and warnings, even if an exception is raised :raises ValueError: If validation fails based on strict mode setting. @@ -176,66 +308,26 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove # 2. Validate each table (continue checking all tables even if errors found) for table_name, table_def in self.datasets.items(): - # Check table has at least one column - if not table_def.columns: - result.add_error(f"Table '{table_name}' must have at least one column") - continue # Skip further checks for this table since it has no columns + # Validate basic properties (returns False if no columns, skip further checks) + if not _validate_table_basic_properties(table_name, table_def, result): + continue - # Check row count is positive - if table_def.number_of_rows <= 0: - result.add_error( - f"Table '{table_name}' has invalid number_of_rows: {table_def.number_of_rows}. " - "Must be a positive integer." 
- ) + # Validate duplicate columns + _validate_duplicate_columns(table_name, table_def, result) - # Check partitions if specified - # Can we find a way to use the default way? - if table_def.partitions is not None and table_def.partitions <= 0: - result.add_error( - f"Table '{table_name}' has invalid partitions: {table_def.partitions}. " - "Must be a positive integer or None." - ) - - # Check for duplicate column names - column_names = [col.name for col in table_def.columns] - duplicates = [name for name in set(column_names) if column_names.count(name) > 1] - if duplicates: - result.add_error( - f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}" - ) - - # Build column map for reference checking - column_map = {col.name: col for col in table_def.columns} - for col in table_def.columns: - if col.baseColumn and col.baseColumn != "id": - if col.baseColumn not in column_map: - result.add_error( - f"Table '{table_name}', column '{col.name}': " - f"baseColumn '{col.baseColumn}' does not exist in the table" - ) + # Validate column references + _validate_column_references(table_name, table_def, result) # Check for circular dependencies in baseColumn references - circular_errors = self._check_circular_dependencies(table_name, table_def.columns) + circular_errors = _check_circular_dependencies(table_name, table_def.columns) for error in circular_errors: result.add_error(error) - # Check primary key constraints - primary_columns = [col for col in table_def.columns if col.primary] - if len(primary_columns) > 1: - primary_names = [col.name for col in primary_columns] - result.add_warning( - f"Table '{table_name}' has multiple primary columns: {', '.join(primary_names)}. " - "This may not be the intended behavior." - ) + # Validate primary key constraints + _validate_primary_key_columns(table_name, table_def, result) - # Check for columns with no type and not using baseColumn properly - for col in table_def.columns: - if not col.primary and not col.type and not col.options: - result.add_warning( - f"Table '{table_name}', column '{col.name}': " - "No type specified and no options provided. " - "Column may not generate data as expected." - ) + # Validate column types + _validate_column_types(table_name, table_def, result) # 3. Check output destination if not self.output_destination: @@ -244,18 +336,8 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove "Set output_destination to save generated data." ) - # 4. Validate generator options (if any known options) - if self.generator_options: - known_options = [ - "random", "randomSeed", "randomSeedMethod", "verbose", - "debug", "seedColumnName" - ] - for key in self.generator_options: - if key not in known_options: - result.add_warning( - f"Unknown generator option: '{key}'. " - "This may be ignored during generation." - ) + # 4. Validate generator options + self._validate_generator_options(result) # Now that all validations are complete, decide whether to raise if (strict and (result.errors or result.warnings)) or (not strict and result.errors): diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index ac42204f..c0ceb191 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -20,6 +20,68 @@ INTERNAL_ID_COLUMN_NAME = "id" +def _columnSpecToDatagenColumnSpec(col_def: ColumnDefinition) -> dict[str, Any]: + """Convert a ColumnDefinition spec into dbldatagen DataGenerator column arguments. 
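As an illustration of the cycle detection that the refactored validate() now delegates to _check_circular_dependencies, a hedged sketch mirroring test_circular_dependency_raises_error (assumes pytest and the `datasets` field name adopted later in this series):

import pytest

from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition

spec = DatagenSpec(
    datasets={
        "users": DatasetDefinition(
            number_of_rows=100,
            columns=[
                ColumnDefinition(name="id", type="int", primary=True),
                # "a" and "b" reference each other via baseColumn, forming a cycle.
                ColumnDefinition(name="a", type="int", baseColumn="b"),
                ColumnDefinition(name="b", type="int", baseColumn="a"),
            ],
        )
    }
)

# The a -> b -> a chain is reported as an error, so validation fails even in
# non-strict mode.
with pytest.raises(ValueError):
    spec.validate(strict=False)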
+ + This function translates the declarative ColumnDefinition format into the + keyword arguments expected by dbldatagen's withColumn() method. It handles special + cases like primary keys, nullable columns, and omitted columns. + + Primary key columns receive special treatment: + - Automatically use the internal ID column as their base + - String primary keys use hash-based generation + - Numeric primary keys maintain sequential values + + :param col_def: ColumnDefinition object from a DatagenSpec + :returns: Dictionary of keyword arguments suitable for DataGenerator.withColumn() + + .. note:: + Conflicting options for primary keys (like min/max, values, expr) will generate + warnings but won't prevent generation - the primary key behavior takes precedence + """ + col_name = col_def.name + col_type = col_def.type + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_def.primary: + kwargs["colType"] = col_type + kwargs["baseColumn"] = INTERNAL_ID_COLUMN_NAME + + if col_type == "string": + kwargs["baseColumnType"] = "hash" + elif col_type not in ["int", "long", "integer", "bigint", "short"]: + kwargs["baseColumnType"] = "auto" + logger.warning( + f"Primary key '{col_name}' has non-standard type '{col_type}'") + + # Log conflicting options for primary keys + conflicting_opts_for_pk = [ + "distribution", "template", "dataRange", "random", "omit", + "min", "max", "uniqueValues", "values", "expr" + ] + + for opt_key in conflicting_opts_for_pk: + if opt_key in kwargs: + logger.warning( + f"Primary key '{col_name}': Option '{opt_key}' may be ignored") + + if col_def.omit is not None and col_def.omit: + kwargs["omit"] = True + else: + kwargs = col_def.options.copy() if col_def.options is not None else {} + + if col_type: + kwargs["colType"] = col_type + if col_def.baseColumn: + kwargs["baseColumn"] = col_def.baseColumn + if col_def.baseColumnType: + kwargs["baseColumnType"] = col_def.baseColumnType + if col_def.omit is not None: + kwargs["omit"] = col_def.omit + + return kwargs + + class Generator: """Main orchestrator for generating synthetic data from DatagenSpec configurations. @@ -59,74 +121,9 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> "SparkSession cannot be None during Generator initialization") raise RuntimeError("SparkSession cannot be None") self.spark = spark - self._created_spark_session = False self.app_name = app_name logger.info("Generator initialized with SparkSession") - def _columnSpecToDatagenColumnSpec(self, col_def: ColumnDefinition) -> dict[str, Any]: - """Convert a ColumnDefinition spec into dbldatagen DataGenerator column arguments. - - This internal method translates the declarative ColumnDefinition format into the - keyword arguments expected by dbldatagen's withColumn() method. It handles special - cases like primary keys, nullable columns, and omitted columns. - - Primary key columns receive special treatment: - - Automatically use the internal ID column as their base - - String primary keys use hash-based generation - - Numeric primary keys maintain sequential values - - :param col_def: ColumnDefinition object from a DatagenSpec - :returns: Dictionary of keyword arguments suitable for DataGenerator.withColumn() - - .. note:: - This is an internal method not intended for direct use by end users - - .. 
note:: - Conflicting options for primary keys (like min/max, values, expr) will generate - warnings but won't prevent generation - the primary key behavior takes precedence - """ - col_name = col_def.name - col_type = col_def.type - kwargs = col_def.options.copy() if col_def.options is not None else {} - - if col_def.primary: - kwargs["colType"] = col_type - kwargs["baseColumn"] = INTERNAL_ID_COLUMN_NAME - - if col_type == "string": - kwargs["baseColumnType"] = "hash" - elif col_type not in ["int", "long", "integer", "bigint", "short"]: - kwargs["baseColumnType"] = "auto" - logger.warning( - f"Primary key '{col_name}' has non-standard type '{col_type}'") - - # Log conflicting options for primary keys - conflicting_opts_for_pk = [ - "distribution", "template", "dataRange", "random", "omit", - "min", "max", "uniqueValues", "values", "expr" - ] - - for opt_key in conflicting_opts_for_pk: - if opt_key in kwargs: - logger.warning( - f"Primary key '{col_name}': Option '{opt_key}' may be ignored") - - if col_def.omit is not None and col_def.omit: - kwargs["omit"] = True - else: - kwargs = col_def.options.copy() if col_def.options is not None else {} - - if col_type: - kwargs["colType"] = col_type - if col_def.baseColumn: - kwargs["baseColumn"] = col_def.baseColumn - if col_def.baseColumnType: - kwargs["baseColumnType"] = col_def.baseColumnType - if col_def.omit is not None: - kwargs["omit"] = col_def.omit - - return kwargs - def _prepareDataGenerators( self, config: DatagenSpec, @@ -191,7 +188,7 @@ def _prepareDataGenerators( # Process each column for col_def in table_spec.columns: - kwargs = self._columnSpecToDatagenColumnSpec(col_def) + kwargs = _columnSpecToDatagenColumnSpec(col_def) data_gen = data_gen.withColumn(colName=col_def.name, **kwargs) # Has performance implications. @@ -206,9 +203,9 @@ def _prepareDataGenerators( logger.info("All data generators prepared successfully") return prepared_generators - def writePreparedData( - self, - prepared_generators: dict[str, dg.DataGenerator], + @staticmethod + def _writePreparedData( + prepared_generators: dict[str, dg.DataGenerator], output_destination: Union[UCSchemaTarget, FilePathTarget, None], config_source_name: str = "PydanticConfig", ) -> None: @@ -244,8 +241,8 @@ def writePreparedData( logger.info("Starting data writing phase") if not prepared_generators: - logger.warning("No prepared data generators to write") - return + logger.error("No prepared data generators to write") + raise RuntimeError("No prepared data generators to write") for table_name, data_gen in prepared_generators.items(): logger.info(f"Writing table: {table_name}") @@ -315,7 +312,7 @@ def generateAndWriteData( Example: >>> spec = DatagenSpec( - ... tables={"users": user_table_def}, + ... datasets={"users": user_table_def}, ... output_destination=UCSchemaTarget(catalog="main", schema_="test") ... ) >>> spec.validate() # Check for errors first @@ -334,7 +331,7 @@ def generateAndWriteData( return # Phase 2: Write data - self.writePreparedData( + self._writePreparedData( prepared_generators_map, config.output_destination, config_source_name diff --git a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py index 1095d270..72714ba6 100644 --- a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py +++ b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py @@ -4,7 +4,7 @@ the basic stock ticker dataset, corresponding to the BasicStockTickerProvider. 
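A hedged sketch of how the stricter behavior of _writePreparedData above could be exercised; it calls a private static helper directly, purely for illustration, and assumes pyspark and dbldatagen are importable:

import pytest

from dbldatagen.spec.generator_spec_impl import Generator


def test_write_prepared_data_raises_on_empty_map():
    # An empty generator map is now an error rather than a silent no-op; no
    # SparkSession or Generator instance is needed to reach this check.
    with pytest.raises(RuntimeError, match="No prepared data generators"):
        Generator._writePreparedData({}, output_destination=None)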
""" -from random import random +import random from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition from dbldatagen.spec.column_spec import ColumnDefinition @@ -47,9 +47,9 @@ def create_basic_stock_ticker_spec( # Generate random values for start_value, growth_rate, and volatility # These need to be pre-computed for the values option num_value_sets = max(1, int(num_symbols / 10)) - start_values = [1.0 + 199.0 * random() for _ in range(num_value_sets)] - growth_rates = [-0.1 + 0.35 * random() for _ in range(num_value_sets)] - volatility_values = [0.0075 * random() for _ in range(num_value_sets)] + start_values = [1.0 + 199.0 * random.random() for _ in range(num_value_sets)] + growth_rates = [-0.1 + 0.35 * random.random() for _ in range(num_value_sets)] + volatility_values = [0.0075 * random.random() for _ in range(num_value_sets)] columns = [ # Symbol ID (numeric identifier for symbol) diff --git a/tests/test_datagen_specs.py b/tests/test_datagen_specs.py index 589d5e6b..215041ee 100644 --- a/tests/test_datagen_specs.py +++ b/tests/test_datagen_specs.py @@ -37,7 +37,7 @@ def test_basic_user_spec_creation(self): ) spec = DatagenSpec( - tables={"users": table_def}, + datasets={"users": table_def}, output_destination=None ) @@ -68,7 +68,7 @@ def test_basic_user_spec_validation(self): ) spec = DatagenSpec( - tables={"users": table_def} + datasets={"users": table_def} ) validation_result = spec.validate(strict=False) @@ -98,7 +98,7 @@ def test_column_with_base_column(self): ) spec = DatagenSpec( - tables={"symbols": table_def} + datasets={"symbols": table_def} ) validation_result = spec.validate(strict=False) @@ -145,7 +145,7 @@ def test_basic_stock_ticker_spec_creation(self): ) spec = DatagenSpec( - tables={"stock_tickers": table_def}, + datasets={"stock_tickers": table_def}, output_destination=None ) @@ -181,7 +181,7 @@ def test_stock_ticker_with_omitted_columns(self): ) spec = DatagenSpec( - tables={"prices": table_def} + datasets={"prices": table_def} ) validation_result = spec.validate(strict=False) @@ -198,7 +198,7 @@ class TestDatagenSpecValidation(unittest.TestCase): def test_empty_tables_validation(self): """Test that spec with no tables fails validation.""" - spec = DatagenSpec(tables={}) + spec = DatagenSpec(datasets={}) with self.assertRaises(ValueError) as context: spec.validate(strict=False) @@ -218,7 +218,7 @@ def test_duplicate_column_names(self): columns=columns ) - spec = DatagenSpec(tables={"test": table_def}) + spec = DatagenSpec(datasets={"test": table_def}) with self.assertRaises(ValueError) as context: spec.validate(strict=False) @@ -240,7 +240,7 @@ def test_negative_rows_validation(self): columns=columns ) - spec = DatagenSpec(tables={"test": table_def}) + spec = DatagenSpec(datasets={"test": table_def}) with self.assertRaises(ValueError) as context: spec.validate(strict=False) @@ -261,7 +261,7 @@ def test_spec_with_generator_options(self): ) spec = DatagenSpec( - tables={"test": table_def}, + datasets={"test": table_def}, generator_options={ "randomSeedMethod": "hash_fieldname", "verbose": True diff --git a/tests/test_specs.py b/tests/test_specs.py index 4e994c5f..a44f06c6 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -101,7 +101,7 @@ class TestDatagenSpecValidation: def test_valid_spec_passes_validation(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -118,14 +118,14 @@ def test_valid_spec_passes_validation(self): assert len(result.errors) == 0 def 
test_empty_tables_raises_error(self): - spec = DatagenSpec(tables={}) + spec = DatagenSpec(datasets={}) with pytest.raises(ValueError, match="at least one table"): spec.validate(strict=True) def test_table_without_columns_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "empty_table": DatasetDefinition( number_of_rows=100, columns=[] @@ -138,7 +138,7 @@ def test_table_without_columns_raises_error(self): def test_negative_row_count_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=-10, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -151,7 +151,7 @@ def test_negative_row_count_raises_error(self): def test_zero_row_count_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=0, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -164,7 +164,7 @@ def test_zero_row_count_raises_error(self): def test_invalid_partitions_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, partitions=-5, @@ -178,7 +178,7 @@ def test_invalid_partitions_raises_error(self): def test_duplicate_column_names_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -195,7 +195,7 @@ def test_duplicate_column_names_raises_error(self): def test_invalid_base_column_reference_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -211,7 +211,7 @@ def test_invalid_base_column_reference_raises_error(self): def test_circular_dependency_raises_error(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -229,7 +229,7 @@ def test_circular_dependency_raises_error(self): def test_multiple_primary_columns_warning(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -252,7 +252,7 @@ def test_multiple_primary_columns_warning(self): def test_column_without_type_or_options_warning(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ @@ -270,7 +270,7 @@ def test_column_without_type_or_options_warning(self): def test_no_output_destination_warning(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -285,7 +285,7 @@ def test_no_output_destination_warning(self): def test_unknown_generator_option_warning(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -302,7 +302,7 @@ def test_unknown_generator_option_warning(self): def test_multiple_errors_collected(self): """Test that all errors are collected before raising""" spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=-10, # Error 1 partitions=0, # Error 2 @@ -327,7 +327,7 @@ def test_multiple_errors_collected(self): def test_strict_mode_raises_on_warnings(self): spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -347,7 +347,7 @@ def test_strict_mode_raises_on_warnings(self): def test_valid_base_column_chain(self): """Test that valid baseColumn chains work""" spec = DatagenSpec( - tables={ + datasets={ "users": 
DatasetDefinition( number_of_rows=100, columns=[ @@ -366,7 +366,7 @@ def test_valid_base_column_chain(self): def test_multiple_tables_validation(self): """Test validation across multiple tables""" spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] @@ -423,7 +423,7 @@ class TestValidationIntegration: def test_realistic_valid_spec(self): """Test a realistic, valid specification""" spec = DatagenSpec( - tables={ + datasets={ "users": DatasetDefinition( number_of_rows=1000, partitions=4, diff --git a/tests/test_standard_datasets.py b/tests/test_standard_datasets.py index e8272b4f..90afac04 100644 --- a/tests/test_standard_datasets.py +++ b/tests/test_standard_datasets.py @@ -190,7 +190,7 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=1000000, parti ds_definition = X1.getDatasetDefinition() print("ds_definition", ds_definition) assert ds_definition.name == "providers/X1" - assert ds_definition.datasets == [DatasetProvider.DEFAULT_TABLE_NAME] + assert ds_definition.tables == [DatasetProvider.DEFAULT_TABLE_NAME] assert ds_definition.primaryTable == DatasetProvider.DEFAULT_TABLE_NAME assert ds_definition.summary is not None assert ds_definition.description is not None From 7da16dd715703feb05d039331c3b42ac8ad95eb1 Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Thu, 11 Dec 2025 11:18:46 -0500 Subject: [PATCH 19/20] renaming tables to datasets in docs and test, and black fmt changes --- dbldatagen/__init__.py | 1 + dbldatagen/spec/__init__.py | 3 + dbldatagen/spec/column_spec.py | 1 + dbldatagen/spec/generator_spec.py | 58 ++---- dbldatagen/spec/generator_spec_impl.py | 70 +++----- dbldatagen/spec/output_targets.py | 5 +- dbldatagen/text_generator_plugins.py | 4 - .../basic_stock_ticker_datagen_spec.py | 2 +- .../basic_user_datagen_spec.py | 2 +- makefile | 2 +- pyproject.toml | 11 ++ tests/test_datagen_specs.py | 168 ++++-------------- tests/test_datasets_with_specs.py | 21 +-- tests/test_specs.py | 124 ++++--------- 14 files changed, 143 insertions(+), 329 deletions(-) diff --git a/dbldatagen/__init__.py b/dbldatagen/__init__.py index 2104a35a..eaedeec0 100644 --- a/dbldatagen/__init__.py +++ b/dbldatagen/__init__.py @@ -97,6 +97,7 @@ "datagen_types", ] + def python_version_check(python_version_expected): """Check against Python version diff --git a/dbldatagen/spec/__init__.py b/dbldatagen/spec/__init__.py index a22fb217..b8eed5a3 100644 --- a/dbldatagen/spec/__init__.py +++ b/dbldatagen/spec/__init__.py @@ -39,11 +39,14 @@ def __getattr__(name: str) -> Any: # noqa: ANN401 """ if name == "ColumnSpec": from .column_spec import ColumnDefinition # noqa: PLC0415 + return ColumnDefinition elif name == "GeneratorSpec": from .generator_spec import DatagenSpec # noqa: PLC0415 + return DatagenSpec elif name == "GeneratorSpecImpl": from .generator_spec_impl import Generator # noqa: PLC0415 + return Generator raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dbldatagen/spec/column_spec.py b/dbldatagen/spec/column_spec.py index 3d81c462..dc121ce1 100644 --- a/dbldatagen/spec/column_spec.py +++ b/dbldatagen/spec/column_spec.py @@ -42,6 +42,7 @@ class ColumnDefinition(BaseModel): Columns can be chained via baseColumn references, but circular dependencies will be caught during validation """ + name: str type: DbldatagenBasicType | None = None primary: bool = False diff --git a/dbldatagen/spec/generator_spec.py b/dbldatagen/spec/generator_spec.py index 
cfcba822..63997aa0 100644 --- a/dbldatagen/spec/generator_spec.py +++ b/dbldatagen/spec/generator_spec.py @@ -16,11 +16,7 @@ logger = logging.getLogger(__name__) -def _validate_table_basic_properties( - table_name: str, - table_def: DatasetDefinition, - result: ValidationResult -) -> bool: +def _validate_table_basic_properties(table_name: str, table_def: DatasetDefinition, result: ValidationResult) -> bool: """Validate basic table properties like columns, row count, and partitions. :param table_name: Name of the table being validated @@ -51,11 +47,7 @@ def _validate_table_basic_properties( return True -def _validate_duplicate_columns( - table_name: str, - table_def: DatasetDefinition, - result: ValidationResult -) -> None: +def _validate_duplicate_columns(table_name: str, table_def: DatasetDefinition, result: ValidationResult) -> None: """Check for duplicate column names within a table. :param table_name: Name of the table being validated @@ -65,16 +57,10 @@ def _validate_duplicate_columns( column_names = [col.name for col in table_def.columns] duplicates = [name for name in set(column_names) if column_names.count(name) > 1] if duplicates: - result.add_error( - f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}" - ) + result.add_error(f"Table '{table_name}' has duplicate column names: {', '.join(duplicates)}") -def _validate_column_references( - table_name: str, - table_def: DatasetDefinition, - result: ValidationResult -) -> None: +def _validate_column_references(table_name: str, table_def: DatasetDefinition, result: ValidationResult) -> None: """Validate that baseColumn references point to existing columns. :param table_name: Name of the table being validated @@ -91,11 +77,7 @@ def _validate_column_references( ) -def _validate_primary_key_columns( - table_name: str, - table_def: DatasetDefinition, - result: ValidationResult -) -> None: +def _validate_primary_key_columns(table_name: str, table_def: DatasetDefinition, result: ValidationResult) -> None: """Validate primary key column constraints. :param table_name: Name of the table being validated @@ -111,11 +93,7 @@ def _validate_primary_key_columns( ) -def _validate_column_types( - table_name: str, - table_def: DatasetDefinition, - result: ValidationResult -) -> None: +def _validate_column_types(table_name: str, table_def: DatasetDefinition, result: ValidationResult) -> None: """Validate column type specifications. :param table_name: Name of the table being validated @@ -131,10 +109,7 @@ def _validate_column_types( ) -def _check_circular_dependencies( - table_name: str, - columns: list[ColumnDefinition] -) -> list[str]: +def _check_circular_dependencies(table_name: str, columns: list[ColumnDefinition]) -> list[str]: """Check for circular dependencies in baseColumn references within a table. Analyzes column dependencies to detect cycles where columns reference each other @@ -209,6 +184,7 @@ class DatasetDefinition(BaseModel): .. note:: Column order in the list determines the order of columns in the generated output """ + number_of_rows: int partitions: int | None = None columns: list[ColumnDefinition] @@ -246,10 +222,13 @@ class DatagenSpec(BaseModel): Multiple tables can share the same DatagenSpec and will be generated in the order they appear in the tables dictionary """ + datasets: dict[str, DatasetDefinition] - output_destination: Union[UCSchemaTarget, FilePathTarget] | None = None # there is a abstraction, may be we can use that? 
talk to Greg + output_destination: Union[UCSchemaTarget, FilePathTarget] | None = ( + None # there is a abstraction, may be we can use that? talk to Greg + ) generator_options: dict[str, Any] | None = None - intended_for_databricks: bool | None = None # May be infered. + intended_for_databricks: bool | None = None # May be infered. def _validate_generator_options(self, result: ValidationResult) -> None: """Validate generator options against known valid options. @@ -257,16 +236,10 @@ def _validate_generator_options(self, result: ValidationResult) -> None: :param result: ValidationResult to collect errors/warnings """ if self.generator_options: - known_options = [ - "random", "randomSeed", "randomSeedMethod", "verbose", - "debug", "seedColumnName" - ] + known_options = ["random", "randomSeed", "randomSeedMethod", "verbose", "debug", "seedColumnName"] for key in self.generator_options: if key not in known_options: - result.add_warning( - f"Unknown generator option: '{key}'. " - "This may be ignored during generation." - ) + result.add_warning(f"Unknown generator option: '{key}'. " "This may be ignored during generation.") def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[override] """Validate the entire DatagenSpec configuration comprehensively. @@ -345,7 +318,6 @@ def validate(self, strict: bool = True) -> ValidationResult: # type: ignore[ove return result - def display_all_tables(self) -> None: """Display a formatted view of all table definitions in the spec. diff --git a/dbldatagen/spec/generator_spec_impl.py b/dbldatagen/spec/generator_spec_impl.py index c0ceb191..d7953f95 100644 --- a/dbldatagen/spec/generator_spec_impl.py +++ b/dbldatagen/spec/generator_spec_impl.py @@ -51,19 +51,25 @@ def _columnSpecToDatagenColumnSpec(col_def: ColumnDefinition) -> dict[str, Any]: kwargs["baseColumnType"] = "hash" elif col_type not in ["int", "long", "integer", "bigint", "short"]: kwargs["baseColumnType"] = "auto" - logger.warning( - f"Primary key '{col_name}' has non-standard type '{col_type}'") + logger.warning(f"Primary key '{col_name}' has non-standard type '{col_type}'") # Log conflicting options for primary keys conflicting_opts_for_pk = [ - "distribution", "template", "dataRange", "random", "omit", - "min", "max", "uniqueValues", "values", "expr" + "distribution", + "template", + "dataRange", + "random", + "omit", + "min", + "max", + "uniqueValues", + "values", + "expr", ] for opt_key in conflicting_opts_for_pk: if opt_key in kwargs: - logger.warning( - f"Primary key '{col_name}': Option '{opt_key}' may be ignored") + logger.warning(f"Primary key '{col_name}': Option '{opt_key}' may be ignored") if col_def.omit is not None and col_def.omit: kwargs["omit"] = True @@ -117,17 +123,14 @@ def __init__(self, spark: SparkSession, app_name: str = "DataGen_ClassBased") -> :raises RuntimeError: If spark is None or not properly initialized """ if not spark: - logger.error( - "SparkSession cannot be None during Generator initialization") + logger.error("SparkSession cannot be None during Generator initialization") raise RuntimeError("SparkSession cannot be None") self.spark = spark self.app_name = app_name logger.info("Generator initialized with SparkSession") def _prepareDataGenerators( - self, - config: DatagenSpec, - config_source_name: str = "PydanticConfig" + self, config: DatagenSpec, config_source_name: str = "PydanticConfig" ) -> dict[str, dg.DataGenerator]: """Prepare DataGenerator objects for all tables defined in the spec. 
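A small sketch of the primary-key translation performed by the module-level _columnSpecToDatagenColumnSpec above; the helper is private and is called directly here only for illustration:

from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec_impl import _columnSpecToDatagenColumnSpec

# A string primary key is based on the internal "id" column and generated via
# hashing; int/long primary keys leave baseColumnType unset.
kwargs = _columnSpecToDatagenColumnSpec(ColumnDefinition(name="user_id", type="string", primary=True))
assert kwargs["colType"] == "string"
assert kwargs["baseColumn"] == "id"  # INTERNAL_ID_COLUMN_NAME
assert kwargs["baseColumnType"] == "hash"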
@@ -157,20 +160,17 @@ def _prepareDataGenerators( Preparation is separate from building to allow inspection and modification of DataGenerators before data generation begins """ - logger.info( - f"Preparing data generators for {len(config.datasets)} tables") + logger.info(f"Preparing data generators for {len(config.datasets)} tables") if not self.spark: - logger.error( - "SparkSession is not available. Cannot prepare data generators") - raise RuntimeError( - "SparkSession is not available. Cannot prepare data generators") + logger.error("SparkSession is not available. Cannot prepare data generators") + raise RuntimeError("SparkSession is not available. Cannot prepare data generators") tables_config: dict[str, DatasetDefinition] = config.datasets global_gen_options = config.generator_options if config.generator_options else {} prepared_generators: dict[str, dg.DataGenerator] = {} - generation_order = list(tables_config.keys()) # This becomes important when we get into multitable + generation_order = list(tables_config.keys()) # This becomes important when we get into multitable for table_name in generation_order: table_spec = tables_config[table_name] @@ -190,22 +190,19 @@ def _prepareDataGenerators( for col_def in table_spec.columns: kwargs = _columnSpecToDatagenColumnSpec(col_def) data_gen = data_gen.withColumn(colName=col_def.name, **kwargs) - # Has performance implications. - prepared_generators[table_name] = data_gen logger.info(f"Successfully prepared table: {table_name}") except Exception as e: logger.error(f"Failed to prepare table '{table_name}': {e}") - raise RuntimeError( - f"Failed to prepare table '{table_name}': {e}") from e + raise RuntimeError(f"Failed to prepare table '{table_name}': {e}") from e logger.info("All data generators prepared successfully") return prepared_generators @staticmethod def _writePreparedData( - prepared_generators: dict[str, dg.DataGenerator], + prepared_generators: dict[str, dg.DataGenerator], output_destination: Union[UCSchemaTarget, FilePathTarget, None], config_source_name: str = "PydanticConfig", ) -> None: @@ -252,7 +249,8 @@ def _writePreparedData( requested_rows = data_gen.rowCount actual_row_count = df.count() logger.info( - f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})") + f"Built DataFrame for '{table_name}': {actual_row_count} rows (requested: {requested_rows})" + ) if actual_row_count == 0 and requested_rows is not None and requested_rows > 0: logger.warning(f"Table '{table_name}': Requested {requested_rows} rows but built 0") @@ -275,11 +273,7 @@ def _writePreparedData( raise RuntimeError(f"Failed to write table '{table_name}': {e}") from e logger.info("All data writes completed successfully") - def generateAndWriteData( - self, - config: DatagenSpec, - config_source_name: str = "PydanticConfig" - ) -> None: + def generateAndWriteData(self, config: DatagenSpec, config_source_name: str = "PydanticConfig") -> None: """Execute the complete data generation workflow from spec to output. This is the primary high-level method for generating data from a DatagenSpec. 
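For context, a hedged end-to-end sketch of the workflow this method drives, using the output target classes from output_targets.py in this patch; the catalog, schema, and row count are illustrative only:

from pyspark.sql import SparkSession

from dbldatagen.spec.column_spec import ColumnDefinition
from dbldatagen.spec.generator_spec import DatagenSpec, DatasetDefinition, UCSchemaTarget
from dbldatagen.spec.generator_spec_impl import Generator

spec = DatagenSpec(
    datasets={
        "users": DatasetDefinition(
            number_of_rows=1000,
            columns=[ColumnDefinition(name="id", type="int", primary=True)],
        )
    },
    # UCSchemaTarget writes Delta tables to {catalog}.{schema_}.{table_name};
    # FilePathTarget(base_path=..., output_format="parquet") would target files instead.
    output_destination=UCSchemaTarget(catalog="main", schema_="synthetic"),
)

spec.validate(strict=False)  # surface configuration problems before generating
Generator(SparkSession.builder.getOrCreate()).generateAndWriteData(spec)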
It @@ -326,22 +320,14 @@ def generateAndWriteData( prepared_generators_map = self._prepareDataGenerators(config, config_source_name) if not prepared_generators_map and list(config.datasets.keys()): - logger.warning( - "No data generators were successfully prepared, though tables were defined") + logger.warning("No data generators were successfully prepared, though tables were defined") return # Phase 2: Write data - self._writePreparedData( - prepared_generators_map, - config.output_destination, - config_source_name - ) + self._writePreparedData(prepared_generators_map, config.output_destination, config_source_name) - logger.info( - "Combined data generation and writing completed successfully") + logger.info("Combined data generation and writing completed successfully") except Exception as e: - logger.error( - f"Error during combined data generation and writing: {e}") - raise RuntimeError( - f"Error during combined data generation and writing: {e}") from e + logger.error(f"Error during combined data generation and writing: {e}") + raise RuntimeError(f"Error during combined data generation and writing: {e}") from e diff --git a/dbldatagen/spec/output_targets.py b/dbldatagen/spec/output_targets.py index b403304a..bb9b1400 100644 --- a/dbldatagen/spec/output_targets.py +++ b/dbldatagen/spec/output_targets.py @@ -25,6 +25,7 @@ class UCSchemaTarget(BaseModel): .. note:: Tables will be written to the location: `{catalog}.{schema_}.{table_name}` """ + catalog: str schema_: str output_format: str = "delta" # Default to delta for UC Schema @@ -47,8 +48,7 @@ def validate_identifiers(cls, v: str) -> str: if not v.strip(): raise ValueError("Identifier must be non-empty.") if not v.isidentifier(): - logger.warning( - f"'{v}' is not a basic Python identifier. Ensure validity for Unity Catalog.") + logger.warning(f"'{v}' is not a basic Python identifier. 
Ensure validity for Unity Catalog.") return v.strip() def __str__(self) -> str: @@ -77,6 +77,7 @@ class FilePathTarget(BaseModel): The base_path can be a local file system path, DBFS path, or cloud storage path (e.g., s3://, gs://, abfs://) depending on your environment """ + base_path: str output_format: Literal["csv", "parquet"] # No default, must be specified diff --git a/dbldatagen/text_generator_plugins.py b/dbldatagen/text_generator_plugins.py index 63422090..ba4759df 100644 --- a/dbldatagen/text_generator_plugins.py +++ b/dbldatagen/text_generator_plugins.py @@ -218,10 +218,6 @@ def initFaker(ctx): rootProperty="faker", name="FakerText")) """ - _name: str - _initPerBatch: bool - _initFn: Callable | None - _rootProperty: object | None _name: str _initPerBatch: bool diff --git a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py index 72714ba6..944d6eb7 100644 --- a/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py +++ b/examples/datagen_from_specs/basic_stock_ticker_datagen_spec.py @@ -273,7 +273,7 @@ def create_basic_stock_ticker_spec( ) spec = DatagenSpec( - tables={"stock_tickers": table_def}, + datasets={"stock_tickers": table_def}, output_destination=None, # No automatic persistence generator_options={ "randomSeedMethod": "hash_fieldname" diff --git a/examples/datagen_from_specs/basic_user_datagen_spec.py b/examples/datagen_from_specs/basic_user_datagen_spec.py index 450f9b9e..2e58bf5a 100644 --- a/examples/datagen_from_specs/basic_user_datagen_spec.py +++ b/examples/datagen_from_specs/basic_user_datagen_spec.py @@ -84,7 +84,7 @@ def create_basic_user_spec( ) spec = DatagenSpec( - tables={"users": table_def}, + datasets={"users": table_def}, output_destination=None, # No automatic persistence generator_options={ "randomSeedMethod": "hash_fieldname" diff --git a/makefile b/makefile index a551f964..e8b81031 100644 --- a/makefile +++ b/makefile @@ -3,7 +3,7 @@ all: clean dev lint fmt test clean: - rm -fr clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml + rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml rm -fr **/*.pyc dev: diff --git a/pyproject.toml b/pyproject.toml index c2a3ea41..955e668e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,17 @@ verify = ["black --check .", "mypy .", "pylint --output-format=colorized -j 0 dbldatagen tests"] + +[tool.hatch.envs.test-pydantic] +template = "default" +matrix = [ + { pydantic_version = ["1.10.6", "2.8.2"] } +] +extra-dependencies = [ + "pydantic=={matrix:pydantic_version}" +] + + [tool.black] target-version = ["py310"] line-length = 120 diff --git a/tests/test_datagen_specs.py b/tests/test_datagen_specs.py index 215041ee..f318f87b 100644 --- a/tests/test_datagen_specs.py +++ b/tests/test_datagen_specs.py @@ -13,33 +13,14 @@ class TestBasicUserDatagenSpec(unittest.TestCase): def test_basic_user_spec_creation(self): """Test creating a basic user DatagenSpec.""" columns = [ - ColumnDefinition( - name="customer_id", - type="long", - options={"minValue": 1000000, "maxValue": 9999999999} - ), - ColumnDefinition( - name="name", - type="string", - options={"template": r"\w \w"} - ), - ColumnDefinition( - name="email", - type="string", - options={"template": r"\w@\w.com"} - ), + ColumnDefinition(name="customer_id", type="long", options={"minValue": 1000000, "maxValue": 9999999999}), + ColumnDefinition(name="name", type="string", options={"template": r"\w \w"}), + 
ColumnDefinition(name="email", type="string", options={"template": r"\w@\w.com"}), ] - table_def = DatasetDefinition( - number_of_rows=1000, - partitions=2, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=1000, partitions=2, columns=columns) - spec = DatagenSpec( - datasets={"users": table_def}, - output_destination=None - ) + spec = DatagenSpec(datasets={"users": table_def}, output_destination=None) self.assertIsNotNone(spec) self.assertIn("users", spec.datasets) @@ -50,26 +31,13 @@ def test_basic_user_spec_creation(self): def test_basic_user_spec_validation(self): """Test validating a basic user DatagenSpec.""" columns = [ - ColumnDefinition( - name="customer_id", - type="long", - options={"minValue": 1000000} - ), - ColumnDefinition( - name="name", - type="string", - options={"template": r"\w \w"} - ), + ColumnDefinition(name="customer_id", type="long", options={"minValue": 1000000}), + ColumnDefinition(name="name", type="string", options={"template": r"\w \w"}), ] - table_def = DatasetDefinition( - number_of_rows=100, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=100, columns=columns) - spec = DatagenSpec( - datasets={"users": table_def} - ) + spec = DatagenSpec(datasets={"users": table_def}) validation_result = spec.validate(strict=False) self.assertTrue(validation_result.is_valid()) @@ -78,28 +46,13 @@ def test_basic_user_spec_validation(self): def test_column_with_base_column(self): """Test creating columns that depend on other columns.""" columns = [ - ColumnDefinition( - name="symbol_id", - type="long", - options={"minValue": 1, "maxValue": 100} - ), - ColumnDefinition( - name="symbol", - type="string", - options={ - "expr": "concat('SYM', symbol_id)" - } - ), + ColumnDefinition(name="symbol_id", type="long", options={"minValue": 1, "maxValue": 100}), + ColumnDefinition(name="symbol", type="string", options={"expr": "concat('SYM', symbol_id)"}), ] - table_def = DatasetDefinition( - number_of_rows=50, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=50, columns=columns) - spec = DatagenSpec( - datasets={"symbols": table_def} - ) + spec = DatagenSpec(datasets={"symbols": table_def}) validation_result = spec.validate(strict=False) self.assertTrue(validation_result.is_valid()) @@ -111,43 +64,18 @@ class TestBasicStockTickerDatagenSpec(unittest.TestCase): def test_basic_stock_ticker_spec_creation(self): """Test creating a basic stock ticker DatagenSpec.""" columns = [ + ColumnDefinition(name="symbol", type="string", options={"template": r"\u\u\u"}), ColumnDefinition( - name="symbol", - type="string", - options={"template": r"\u\u\u"} - ), - ColumnDefinition( - name="post_date", - type="date", - options={"expr": "date_add(cast('2024-10-01' as date), floor(id / 100))"} - ), - ColumnDefinition( - name="open", - type="decimal", - options={"minValue": 100.0, "maxValue": 500.0} - ), - ColumnDefinition( - name="close", - type="decimal", - options={"minValue": 100.0, "maxValue": 500.0} - ), - ColumnDefinition( - name="volume", - type="long", - options={"minValue": 100000, "maxValue": 5000000} + name="post_date", type="date", options={"expr": "date_add(cast('2024-10-01' as date), floor(id / 100))"} ), + ColumnDefinition(name="open", type="decimal", options={"minValue": 100.0, "maxValue": 500.0}), + ColumnDefinition(name="close", type="decimal", options={"minValue": 100.0, "maxValue": 500.0}), + ColumnDefinition(name="volume", type="long", options={"minValue": 100000, "maxValue": 5000000}), ] - table_def = DatasetDefinition( - 
number_of_rows=1000, - partitions=2, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=1000, partitions=2, columns=columns) - spec = DatagenSpec( - datasets={"stock_tickers": table_def}, - output_destination=None - ) + spec = DatagenSpec(datasets={"stock_tickers": table_def}, output_destination=None) self.assertIsNotNone(spec) self.assertIn("stock_tickers", spec.datasets) @@ -161,28 +89,15 @@ def test_stock_ticker_with_omitted_columns(self): name="base_price", type="decimal", options={"minValue": 100.0, "maxValue": 500.0}, - omit=True # Intermediate column - ), - ColumnDefinition( - name="open", - type="decimal", - options={"expr": "base_price * 0.99"} - ), - ColumnDefinition( - name="close", - type="decimal", - options={"expr": "base_price * 1.01"} + omit=True, # Intermediate column ), + ColumnDefinition(name="open", type="decimal", options={"expr": "base_price * 0.99"}), + ColumnDefinition(name="close", type="decimal", options={"expr": "base_price * 1.01"}), ] - table_def = DatasetDefinition( - number_of_rows=100, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=100, columns=columns) - spec = DatagenSpec( - datasets={"prices": table_def} - ) + spec = DatagenSpec(datasets={"prices": table_def}) validation_result = spec.validate(strict=False) self.assertTrue(validation_result.is_valid()) @@ -213,10 +128,7 @@ def test_duplicate_column_names(self): ColumnDefinition(name="id", type="string"), # Duplicate! ] - table_def = DatasetDefinition( - number_of_rows=100, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=100, columns=columns) spec = DatagenSpec(datasets={"test": table_def}) @@ -227,18 +139,12 @@ def test_duplicate_column_names(self): self.assertIn("duplicate column names", str(context.exception)) self.assertIn("id", str(context.exception)) - def test_negative_rows_validation(self): """Test that negative row counts fail validation.""" - columns = [ - ColumnDefinition(name="col1", type="long") - ] + columns = [ColumnDefinition(name="col1", type="long")] # Create with negative rows using dict to bypass Pydantic validation - table_def = DatasetDefinition( - number_of_rows=-100, # Invalid - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=-100, columns=columns) # Invalid spec = DatagenSpec(datasets={"test": table_def}) @@ -251,21 +157,12 @@ def test_negative_rows_validation(self): def test_spec_with_generator_options(self): """Test creating spec with generator options.""" - columns = [ - ColumnDefinition(name="value", type="long") - ] + columns = [ColumnDefinition(name="value", type="long")] - table_def = DatasetDefinition( - number_of_rows=100, - columns=columns - ) + table_def = DatasetDefinition(number_of_rows=100, columns=columns) spec = DatagenSpec( - datasets={"test": table_def}, - generator_options={ - "randomSeedMethod": "hash_fieldname", - "verbose": True - } + datasets={"test": table_def}, generator_options={"randomSeedMethod": "hash_fieldname", "verbose": True} ) self.assertIsNotNone(spec.generator_options) @@ -275,6 +172,3 @@ def test_spec_with_generator_options(self): if __name__ == "__main__": unittest.main() - - - diff --git a/tests/test_datasets_with_specs.py b/tests/test_datasets_with_specs.py index 549260f5..661e847f 100644 --- a/tests/test_datasets_with_specs.py +++ b/tests/test_datasets_with_specs.py @@ -16,6 +16,7 @@ class TestBasicUserSpec(unittest.TestCase): def setUp(self): """Set up test fixtures - define model inline to avoid import issues.""" + # Define the model inline to avoid triggering 
Spark imports class BasicUser(BaseModel): customer_id: int = Field(..., ge=1000000) @@ -33,7 +34,7 @@ def test_valid_user_creation(self): name="John Doe", email="john.doe@example.com", ip_addr="192.168.1.100", - phone="(555)-123-4567" + phone="(555)-123-4567", ) self.assertEqual(user.customer_id, 1234567890) @@ -50,7 +51,7 @@ def test_invalid_customer_id(self): name="Jane Smith", email="jane@example.com", ip_addr="10.0.0.1", - phone="555-1234" + phone="555-1234", ) error = context.exception @@ -63,7 +64,7 @@ def test_user_dict_conversion(self): name="John Doe", email="john.doe@example.com", ip_addr="192.168.1.100", - phone="(555)-123-4567" + phone="(555)-123-4567", ) user_dict = user.dict() @@ -78,7 +79,7 @@ def test_user_json_serialization(self): name="John Doe", email="john.doe@example.com", ip_addr="192.168.1.100", - phone="(555)-123-4567" + phone="(555)-123-4567", ) json_str = user.json() @@ -97,6 +98,7 @@ class TestBasicStockTickerSpec(unittest.TestCase): def setUp(self): """Set up test fixtures - define model inline to avoid import issues.""" + class BasicStockTicker(BaseModel): symbol: str = Field(..., min_length=1, max_length=10) post_date: date @@ -119,7 +121,7 @@ def test_valid_ticker_creation(self): high=Decimal("153.75"), low=Decimal("149.80"), adj_close=Decimal("152.35"), - volume=2500000 + volume=2500000, ) self.assertEqual(ticker.symbol, "AAPL") @@ -142,7 +144,7 @@ def test_invalid_volume(self): high=Decimal("310.00"), low=Decimal("295.00"), adj_close=Decimal("304.50"), - volume=-1000 # Negative + volume=-1000, # Negative ) error = context.exception @@ -159,7 +161,7 @@ def test_invalid_negative_price(self): high=Decimal("310.00"), low=Decimal("295.00"), adj_close=Decimal("304.50"), - volume=1000000 + volume=1000000, ) error = context.exception @@ -175,7 +177,7 @@ def test_ticker_dict_conversion(self): high=Decimal("153.75"), low=Decimal("149.80"), adj_close=Decimal("152.35"), - volume=2500000 + volume=2500000, ) ticker_dict = ticker.dict() @@ -193,7 +195,7 @@ def test_ticker_json_serialization(self): high=Decimal("153.75"), low=Decimal("149.80"), adj_close=Decimal("152.35"), - volume=2500000 + volume=2500000, ) json_str = ticker.json() @@ -209,4 +211,3 @@ def test_ticker_json_serialization(self): if __name__ == "__main__": unittest.main() - diff --git a/tests/test_specs.py b/tests/test_specs.py index a44f06c6..6df7602e 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -54,45 +54,25 @@ class TestColumnDefinitionValidation: """Tests for ColumnDefinition validation""" def test_valid_primary_column(self): - col = ColumnDefinition( - name="id", - type="int", - primary=True - ) + col = ColumnDefinition(name="id", type="int", primary=True) assert col.primary assert col.type == "int" def test_primary_column_with_min_max_raises_error(self): with pytest.raises(ValueError, match="cannot have min/max options"): - ColumnDefinition( - name="id", - type="int", - primary=True, - options={"min": 1, "max": 100} - ) + ColumnDefinition(name="id", type="int", primary=True, options={"min": 1, "max": 100}) def test_primary_column_nullable_raises_error(self): with pytest.raises(ValueError, match="cannot be nullable"): - ColumnDefinition( - name="id", - type="int", - primary=True, - nullable=True - ) + ColumnDefinition(name="id", type="int", primary=True, nullable=True) def test_primary_column_without_type_raises_error(self): with pytest.raises(ValueError, match="must have a type defined"): - ColumnDefinition( - name="id", - primary=True - ) + ColumnDefinition(name="id", primary=True) 
def test_non_primary_column_without_type(self): # Should not raise - col = ColumnDefinition( - name="data", - options={"values": ["a", "b", "c"]} - ) + col = ColumnDefinition(name="data", options={"values": ["a", "b", "c"]}) assert col.name == "data" @@ -107,10 +87,10 @@ def test_valid_spec_passes_validation(self): columns=[ ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="name", type="string", options={"values": ["Alice", "Bob"]}), - ] + ], ) }, - output_destination=UCSchemaTarget(catalog="main", schema_="default") + output_destination=UCSchemaTarget(catalog="main", schema_="default"), ) result = spec.validate(strict=False) @@ -124,14 +104,7 @@ def test_empty_tables_raises_error(self): spec.validate(strict=True) def test_table_without_columns_raises_error(self): - spec = DatagenSpec( - datasets={ - "empty_table": DatasetDefinition( - number_of_rows=100, - columns=[] - ) - } - ) + spec = DatagenSpec(datasets={"empty_table": DatasetDefinition(number_of_rows=100, columns=[])}) with pytest.raises(ValueError, match="must have at least one column"): spec.validate() @@ -140,8 +113,7 @@ def test_negative_row_count_raises_error(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=-10, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=-10, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) } ) @@ -153,8 +125,7 @@ def test_zero_row_count_raises_error(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=0, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=0, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) } ) @@ -166,9 +137,7 @@ def test_invalid_partitions_raises_error(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=100, - partitions=-5, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=100, partitions=-5, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) } ) @@ -185,7 +154,7 @@ def test_duplicate_column_names_raises_error(self): ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="duplicate", type="string"), ColumnDefinition(name="duplicate", type="int"), - ] + ], ) } ) @@ -201,7 +170,7 @@ def test_invalid_base_column_reference_raises_error(self): columns=[ ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="email", type="string", baseColumn="nonexistent"), - ] + ], ) } ) @@ -219,7 +188,7 @@ def test_circular_dependency_raises_error(self): ColumnDefinition(name="col_a", type="string", baseColumn="col_b"), ColumnDefinition(name="col_b", type="string", baseColumn="col_c"), ColumnDefinition(name="col_c", type="string", baseColumn="col_a"), - ] + ], ) } ) @@ -235,7 +204,7 @@ def test_multiple_primary_columns_warning(self): columns=[ ColumnDefinition(name="id1", type="int", primary=True), ColumnDefinition(name="id2", type="int", primary=True), - ] + ], ) } ) @@ -258,7 +227,7 @@ def test_column_without_type_or_options_warning(self): columns=[ ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="empty_col"), - ] + ], ) } ) @@ -272,8 +241,7 @@ def test_no_output_destination_warning(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=100, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) } ) @@ -287,11 +255,10 @@ def 
test_unknown_generator_option_warning(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=100, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) }, - generator_options={"unknown_option": "value"} + generator_options={"unknown_option": "value"}, ) result = spec.validate(strict=False) @@ -310,7 +277,7 @@ def test_multiple_errors_collected(self): ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="id", type="string"), # Error 3: duplicate ColumnDefinition(name="email", baseColumn="phone"), # Error 4: nonexistent - ] + ], ) } ) @@ -329,8 +296,7 @@ def test_strict_mode_raises_on_warnings(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=100, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ) } # No output_destination - will generate warning @@ -354,10 +320,10 @@ def test_valid_base_column_chain(self): ColumnDefinition(name="id", type="int", primary=True), ColumnDefinition(name="code", type="string", baseColumn="id"), ColumnDefinition(name="hash", type="string", baseColumn="code"), - ] + ], ) }, - output_destination=FilePathTarget(base_path="/tmp/data", output_format="parquet") + output_destination=FilePathTarget(base_path="/tmp/data", output_format="parquet"), ) result = spec.validate(strict=False) @@ -368,17 +334,13 @@ def test_multiple_tables_validation(self): spec = DatagenSpec( datasets={ "users": DatasetDefinition( - number_of_rows=100, - columns=[ColumnDefinition(name="id", type="int", primary=True)] + number_of_rows=100, columns=[ColumnDefinition(name="id", type="int", primary=True)] ), "orders": DatasetDefinition( number_of_rows=-50, # Error in second table - columns=[ColumnDefinition(name="order_id", type="int", primary=True)] + columns=[ColumnDefinition(name="order_id", type="int", primary=True)], ), - "products": DatasetDefinition( - number_of_rows=200, - columns=[] # Error: no columns - ) + "products": DatasetDefinition(number_of_rows=200, columns=[]), # Error: no columns } ) @@ -429,38 +391,24 @@ def test_realistic_valid_spec(self): partitions=4, columns=[ ColumnDefinition(name="user_id", type="int", primary=True), - ColumnDefinition(name="username", type="string", options={ - "template": r"\w{8,12}" - }), - ColumnDefinition(name="email", type="string", options={ - "template": r"\w.\w@\w.com" - }), - ColumnDefinition(name="age", type="int", options={ - "min": 18, "max": 99 - }), - ] + ColumnDefinition(name="username", type="string", options={"template": r"\w{8,12}"}), + ColumnDefinition(name="email", type="string", options={"template": r"\w.\w@\w.com"}), + ColumnDefinition(name="age", type="int", options={"min": 18, "max": 99}), + ], ), "orders": DatasetDefinition( number_of_rows=5000, columns=[ ColumnDefinition(name="order_id", type="int", primary=True), - ColumnDefinition(name="amount", type="decimal", options={ - "min": 10.0, "max": 1000.0 - }), - ] - ) + ColumnDefinition(name="amount", type="decimal", options={"min": 10.0, "max": 1000.0}), + ], + ), }, - output_destination=UCSchemaTarget( - catalog="main", - schema_="synthetic_data" - ), - generator_options={ - "random": True, - "randomSeed": 42 - } + output_destination=UCSchemaTarget(catalog="main", schema_="synthetic_data"), + generator_options={"random": True, "randomSeed": 42}, ) result = spec.validate(strict=True) assert 
result.is_valid() assert len(result.errors) == 0 - assert len(result.warnings) == 0 \ No newline at end of file + assert len(result.warnings) == 0 From 59bd52372cf2f07ba5cd216eb6055ac200f8759d Mon Sep 17 00:00:00 2001 From: Anup Kalburgi Date: Thu, 11 Dec 2025 13:03:39 -0500 Subject: [PATCH 20/20] fixing the makefile and the project toml after merge --- makefile | 6 +++--- pyproject.toml | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/makefile b/makefile index e8b81031..086821dc 100644 --- a/makefile +++ b/makefile @@ -3,7 +3,7 @@ all: clean dev lint fmt test clean: - rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml + rm -fr clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml rm -fr **/*.pyc dev: @@ -11,10 +11,10 @@ dev: @hatch run which python lint: - hatch run test-pydantic.2.8.2:verify + hatch run dev:verify fmt: - hatch run test-pydantic.2.8.2:fmt + hatch run dev:fmt test: hatch run test-pydantic:test diff --git a/pyproject.toml b/pyproject.toml index 955e668e..0298bf7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,11 @@ dependencies = [ ] python="3.10" -path = ".venv" +[tool.hatch.envs.dev] +template = "default" +extra-dependencies = [ + "pydantic~=2.8.2" +] [tool.hatch.envs.default.scripts] test = "pytest tests/ -n 10 --cov --cov-report=html --timeout 600 --durations 20"