diff --git a/src/data_designer/config/analysis/__init__.py b/src/data_designer/config/analysis/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/config/analysis/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/config/analysis/column_profilers.py b/src/data_designer/config/analysis/column_profilers.py index cbb292a0..0a236782 100644 --- a/src/data_designer/config/analysis/column_profilers.py +++ b/src/data_designer/config/analysis/column_profilers.py @@ -3,11 +3,12 @@ from abc import ABC from enum import Enum -from typing import TypeAlias +from typing import Optional, Union from pydantic import BaseModel, Field from rich.panel import Panel from rich.table import Column, Table +from typing_extensions import TypeAlias from ..base import ConfigBase from ..utils.visualization import ColorPalette @@ -37,20 +38,20 @@ def create_report_section(self) -> Panel: class JudgeScoreProfilerConfig(ConfigBase): model_alias: str - summary_score_sample_size: int | None = Field(default=20, ge=1) + summary_score_sample_size: Optional[int] = Field(default=20, ge=1) class JudgeScoreSample(BaseModel): - score: int | str + score: Union[int, str] reasoning: str class JudgeScoreDistributions(BaseModel): - scores: dict[str, list[int | str]] + scores: dict[str, list[Union[int, str]]] reasoning: dict[str, list[str]] distribution_types: dict[str, ColumnDistributionType] - distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue] - histograms: dict[str, CategoricalHistogramData | MissingValue] + distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]] + histograms: dict[str, Union[CategoricalHistogramData, MissingValue]] class JudgeScoreSummary(BaseModel): @@ -62,7 +63,7 @@ class JudgeScoreSummary(BaseModel): class JudgeScoreProfilerResults(ColumnProfilerResults): column_name: str summaries: dict[str, JudgeScoreSummary] - score_distributions: JudgeScoreDistributions | MissingValue + score_distributions: Union[JudgeScoreDistributions, MissingValue] def create_report_section(self) -> Panel: layout = Table.grid(Column(), expand=True, padding=(2, 0)) diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 49d6530a..991e41b9 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Any, Literal, Optional, Union from pandas import Series from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from ..columns import DataDesignerColumnType from ..sampler_params import SamplerType @@ -39,19 +39,19 @@ def create_report_row_data(self) -> dict[str, str]: ... 
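A note on the recurring change in this diff: PEP 604 unions (`int | None`) and `typing.TypeAlias` require Python 3.10, so annotations are rewritten with `Union`/`Optional` and `TypeAlias` is imported from `typing_extensions`; the `zip(..., strict=False)` removals elsewhere have the same motivation, since `strict` is also 3.10+. A minimal, self-contained sketch of the pattern; the `MaybeCount` alias is purely illustrative:

```python
from typing import Optional, Union

from typing_extensions import TypeAlias

# Python 3.10+ only:                MaybeCount: TypeAlias = int | str | None
# Backported form used in this diff:
MaybeCount: TypeAlias = Optional[Union[int, str]]


def describe(value: MaybeCount) -> str:
    """Describe a value that may be an int, a str, or missing (None)."""
    return "missing" if value is None else f"value={value!r}"


print(describe(None))  # missing
print(describe(3))     # value=3
```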
class GeneralColumnStatistics(BaseColumnStatistics): column_name: str - num_records: int | MissingValue - num_null: int | MissingValue - num_unique: int | MissingValue + num_records: Union[int, MissingValue] + num_null: Union[int, MissingValue] + num_unique: Union[int, MissingValue] pyarrow_dtype: str simple_dtype: str column_type: Literal["general"] = "general" @field_validator("num_null", "num_unique", "num_records", mode="before") - def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue: + def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int) @property - def percent_null(self) -> float | MissingValue: + def percent_null(self) -> Union[float, MissingValue]: return ( self.num_null if self._is_missing_value(self.num_null) @@ -59,7 +59,7 @@ def percent_null(self) -> float | MissingValue: ) @property - def percent_unique(self) -> float | MissingValue: + def percent_unique(self) -> Union[float, MissingValue]: return ( self.num_unique if self._is_missing_value(self.num_unique) @@ -78,17 +78,17 @@ def _general_display_row(self) -> dict[str, str]: def create_report_row_data(self) -> dict[str, str]: return self._general_display_row - def _is_missing_value(self, v: float | int | MissingValue) -> bool: + def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool: return v in set(MissingValue) class LLMTextColumnStatistics(GeneralColumnStatistics): - completion_tokens_mean: float | MissingValue - completion_tokens_median: float | MissingValue - completion_tokens_stddev: float | MissingValue - prompt_tokens_mean: float | MissingValue - prompt_tokens_median: float | MissingValue - prompt_tokens_stddev: float | MissingValue + completion_tokens_mean: Union[float, MissingValue] + completion_tokens_median: Union[float, MissingValue] + completion_tokens_stddev: Union[float, MissingValue] + prompt_tokens_mean: Union[float, MissingValue] + prompt_tokens_median: Union[float, MissingValue] + prompt_tokens_stddev: Union[float, MissingValue] column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value @field_validator( @@ -100,7 +100,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics): "prompt_tokens_stddev", mode="before", ) - def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue: + def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float) def create_report_row_data(self) -> dict[str, Any]: @@ -136,7 +136,7 @@ class LLMJudgedColumnStatistics(LLMTextColumnStatistics): class SamplerColumnStatistics(GeneralColumnStatistics): sampler_type: SamplerType distribution_type: ColumnDistributionType - distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None + distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]] column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value def create_report_row_data(self) -> dict[str, str]: @@ -148,7 +148,7 @@ def create_report_row_data(self) -> dict[str, str]: class SeedDatasetColumnStatistics(GeneralColumnStatistics): distribution_type: ColumnDistributionType - distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None + distribution: 
Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]] column_type: Literal[DataDesignerColumnType.SEED_DATASET.value] = DataDesignerColumnType.SEED_DATASET.value def create_report_row_data(self) -> dict[str, str]: @@ -160,15 +160,15 @@ class ExpressionColumnStatistics(GeneralColumnStatistics): class ValidationColumnStatistics(GeneralColumnStatistics): - num_valid_records: int | MissingValue + num_valid_records: Union[int, MissingValue] column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value @field_validator("num_valid_records", mode="before") - def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue: + def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int) @property - def percent_valid(self) -> float | MissingValue: + def percent_valid(self) -> Union[float, MissingValue]: return ( self.num_valid_records if self._is_missing_value(self.num_valid_records) @@ -181,7 +181,7 @@ def create_report_row_data(self) -> dict[str, str]: class CategoricalHistogramData(BaseModel): - categories: list[float | int | str] + categories: list[Union[float, int, str]] counts: list[int] @model_validator(mode="after") @@ -198,12 +198,12 @@ def from_series(cls, series: Series) -> Self: class CategoricalDistribution(BaseModel): - most_common_value: str | int - least_common_value: str | int + most_common_value: Union[str, int] + least_common_value: Union[str, int] histogram: CategoricalHistogramData @field_validator("most_common_value", "least_common_value", mode="before") - def ensure_python_types(cls, v: str | int) -> str | int: + def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]: return str(v) if not is_int(v) else prepare_number_for_reporting(v, int) @classmethod @@ -217,14 +217,14 @@ def from_series(cls, series: Series) -> Self: class NumericalDistribution(BaseModel): - min: float | int - max: float | int + min: Union[float, int] + max: Union[float, int] mean: float stddev: float median: float @field_validator("min", "max", "mean", "stddev", "median", mode="before") - def ensure_python_types(cls, v: float | int) -> float | int: + def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]: return prepare_number_for_reporting(v, int if is_int(v) else float) @classmethod @@ -239,14 +239,16 @@ def from_series(cls, series: Series) -> Self: ColumnStatisticsT: TypeAlias = Annotated[ - GeneralColumnStatistics - | LLMTextColumnStatistics - | LLMCodeColumnStatistics - | LLMStructuredColumnStatistics - | LLMJudgedColumnStatistics - | SamplerColumnStatistics - | SeedDatasetColumnStatistics - | ValidationColumnStatistics - | ExpressionColumnStatistics, + Union[ + GeneralColumnStatistics, + LLMTextColumnStatistics, + LLMCodeColumnStatistics, + LLMStructuredColumnStatistics, + LLMJudgedColumnStatistics, + SamplerColumnStatistics, + SeedDatasetColumnStatistics, + ValidationColumnStatistics, + ExpressionColumnStatistics, + ], Field(discriminator="column_type"), ] diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index 07647768..f9e0e168 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -3,6 +3,7 @@ from functools import cached_property from pathlib import Path +from typing import Optional, Union 
from pydantic import BaseModel, Field, field_validator @@ -18,8 +19,8 @@ class DatasetProfilerResults(BaseModel): num_records: int target_num_records: int column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1) - side_effect_column_names: list[str] | None = None - column_profiles: list[ColumnProfilerResultsT] | None = None + side_effect_column_names: Optional[list[str]] = None + column_profiles: Optional[list[ColumnProfilerResultsT]] = None @field_validator("num_records", "target_num_records", mode="before") def ensure_python_integers(cls, v: int) -> int: @@ -42,8 +43,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) -> def to_report( self, - save_path: str | Path | None = None, - include_sections: list[ReportSection | DataDesignerColumnType] | None = None, + save_path: Optional[Union[str, Path]] = None, + include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None, ) -> None: """Generate and print an analysis report based on the dataset profiling results. diff --git a/src/data_designer/config/analysis/utils/__init__.py b/src/data_designer/config/analysis/utils/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/config/analysis/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index 35cd7687..ca62a0cd 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -5,7 +5,7 @@ from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union from rich.align import Align from rich.console import Console, Group @@ -49,8 +49,8 @@ class ReportSection(str, Enum): def generate_analysis_report( analysis: DatasetProfilerResults, - save_path: str | Path | None = None, - include_sections: list[ReportSection | DataDesignerColumnType] | None = None, + save_path: Optional[Union[str, Path]] = None, + include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None, ) -> None: """Generate an analysis report for dataset profiling results. 
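For reference, `DatasetProfilerResults.to_report` (changed above) takes an optional save path and an optional list of report sections or column types. A usage sketch; the `results` object, output path, and chosen section are assumptions for illustration:

```python
from pathlib import Path

from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
from data_designer.config.columns import DataDesignerColumnType


def report_llm_text_columns(results: DatasetProfilerResults, save_path: Path) -> None:
    """Print the profiling report for LLM text columns and save a copy to disk.

    `results` is assumed to come from an earlier profiling run; the save path is a placeholder.
    """
    results.to_report(
        save_path=save_path,
        include_sections=[DataDesignerColumnType.LLM_TEXT],
    )
```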
@@ -166,7 +166,7 @@ def create_judge_score_summary_table( layout = Table.grid(Column(), Column(), expand=True, padding=(0, 2)) histogram_table = create_rich_histogram_table( - {str(s): c for s, c in zip(histogram.categories, histogram.counts, strict=False)}, + {str(s): c for s, c in zip(histogram.categories, histogram.counts)}, ("score", "count"), name_style=HIST_NAME_STYLE, value_style=HIST_VALUE_STYLE, diff --git a/src/data_designer/config/base.py b/src/data_designer/config/base.py index 4e4f0f66..64d4d6de 100644 --- a/src/data_designer/config/base.py +++ b/src/data_designer/config/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union import pandas as pd from pydantic import BaseModel, ConfigDict @@ -14,9 +14,9 @@ from .utils.io_helpers import serialize_data if TYPE_CHECKING: - from ..client.results.preview import PreviewResults from .analysis.dataset_profiler import DatasetProfilerResults from .config_builder import DataDesignerConfigBuilder + from .preview_results import PreviewResults DEFAULT_NUM_RECORDS = 10 @@ -66,7 +66,7 @@ def to_dict(self) -> dict[str, Any]: """ return self.model_dump(mode="json") - def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None: + def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]: """Convert the configuration to a YAML string or file. Args: @@ -84,7 +84,7 @@ def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **k with open(path, "w") as f: f.write(yaml_str) - def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None: + def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]: """Convert the configuration to a JSON string or file. 
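The `to_yaml`/`to_json` signatures above return the serialized string when `path` is omitted and write a file (returning `None`) when a path is given. A small helper sketch built on that contract; the config argument is assumed to be any Data Designer config object exposing these methods:

```python
from pathlib import Path


def export_config(config, out_dir: Path) -> str:
    """Write YAML and JSON copies of `config` and return the YAML text.

    `config` is assumed to expose the to_yaml/to_json methods shown above
    (e.g. a built DataDesignerConfig).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    config.to_yaml(out_dir / "config.yaml")            # writes the file, returns None
    config.to_json(out_dir / "config.json", indent=4)  # writes the file, returns None
    yaml_text = config.to_yaml()                       # no path -> returns the YAML string
    assert yaml_text is not None
    return yaml_text
```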
Args: diff --git a/src/data_designer/config/columns.py b/src/data_designer/config/columns.py index 5eada374..c2449499 100644 --- a/src/data_designer/config/columns.py +++ b/src/data_designer/config/columns.py @@ -3,10 +3,10 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Literal, TypeAlias +from typing import Literal, Optional, Type, Union from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .base import ConfigBase from .errors import InvalidColumnTypeError, InvalidConfigError @@ -79,7 +79,7 @@ class SamplerColumnConfig(SingleColumnConfig): sampler_type: SamplerType params: SamplerParamsT conditional_params: dict[str, SamplerParamsT] = {} - convert_to: str | None = None + convert_to: Optional[str] = None @property def column_type(self) -> DataDesignerColumnType: @@ -89,8 +89,8 @@ def column_type(self) -> DataDesignerColumnType: class LLMTextColumnConfig(SingleColumnConfig): prompt: str model_alias: str - system_prompt: str | None = None - multi_modal_context: list[ImageContext] | None = None + system_prompt: Optional[str] = None + multi_modal_context: Optional[list[ImageContext]] = None @property def column_type(self) -> DataDesignerColumnType: @@ -124,7 +124,7 @@ def column_type(self) -> DataDesignerColumnType: class LLMStructuredColumnConfig(LLMTextColumnConfig): - output_format: dict | type[BaseModel] + output_format: Union[dict, Type[BaseModel]] @property def column_type(self) -> DataDesignerColumnType: @@ -140,7 +140,7 @@ def validate_output_format(self) -> Self: class Score(ConfigBase): name: str = Field(..., description="A clear name for this score.") description: str = Field(..., description="An informative and detailed assessment guide for using this score.") - options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.") + options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") class LLMJudgeColumnConfig(LLMTextColumnConfig): @@ -208,16 +208,16 @@ def column_type(self) -> DataDesignerColumnType: } -ColumnConfigT: TypeAlias = ( - ExpressionColumnConfig - | LLMCodeColumnConfig - | LLMJudgeColumnConfig - | LLMStructuredColumnConfig - | LLMTextColumnConfig - | SamplerColumnConfig - | SeedDatasetColumnConfig - | ValidationColumnConfig -) +ColumnConfigT: TypeAlias = Union[ + ExpressionColumnConfig, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + SeedDatasetColumnConfig, + ValidationColumnConfig, +] def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: @@ -234,19 +234,19 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType column_type = resolve_string_enum(column_type, DataDesignerColumnType) if column_type == DataDesignerColumnType.LLM_TEXT: return LLMTextColumnConfig(name=name, **kwargs) - if column_type == DataDesignerColumnType.LLM_CODE: + elif column_type == DataDesignerColumnType.LLM_CODE: return LLMCodeColumnConfig(name=name, **kwargs) - if column_type == DataDesignerColumnType.LLM_STRUCTURED: + elif column_type == DataDesignerColumnType.LLM_STRUCTURED: return LLMStructuredColumnConfig(name=name, **kwargs) - if column_type == DataDesignerColumnType.LLM_JUDGE: + elif column_type == DataDesignerColumnType.LLM_JUDGE: return LLMJudgeColumnConfig(name=name, **kwargs) - if column_type == 
DataDesignerColumnType.VALIDATION: + elif column_type == DataDesignerColumnType.VALIDATION: return ValidationColumnConfig(name=name, **kwargs) - if column_type == DataDesignerColumnType.EXPRESSION: + elif column_type == DataDesignerColumnType.EXPRESSION: return ExpressionColumnConfig(name=name, **kwargs) - if column_type == DataDesignerColumnType.SAMPLER: + elif column_type == DataDesignerColumnType.SAMPLER: return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) - if column_type == DataDesignerColumnType.SEED_DATASET: + elif column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 78b2195d..a33b3658 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -6,6 +6,7 @@ import json import logging from pathlib import Path +from typing import Optional, Union from pygments import highlight from pygments.formatters import HtmlFormatter @@ -16,9 +17,11 @@ from .base import ExportableConfigBase from .columns import ColumnConfigT, DataDesignerColumnType, SeedDatasetColumnConfig, get_column_config_from_kwargs from .data_designer_config import DataDesignerConfig +from .dataset_builders import BuildStage from .datastore import DatastoreSettings, fetch_seed_dataset_column_names from .errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError from .models import ModelConfig, load_model_configs +from .processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs from .sampler_constraints import ( ColumnConstraintT, ColumnInequalityConstraint, @@ -60,7 +63,7 @@ class BuilderConfig(ExportableConfigBase): """ data_designer: DataDesignerConfig - datastore_settings: DatastoreSettings | None + datastore_settings: Optional[DatastoreSettings] class DataDesignerConfigBuilder: @@ -70,7 +73,7 @@ class DataDesignerConfigBuilder: """ @classmethod - def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self: + def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: """Create a DataDesignerConfigBuilder from an existing configuration. Args: @@ -117,7 +120,7 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self: return builder - def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None): + def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None): """Initialize a new DataDesignerConfigBuilder instance. 
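The dispatch above is what lets `add_column` accept either a constructed column config or a `name`/`column_type` plus keyword arguments. A sketch of both styles; the model alias, model id, and Jinja-style prompt templating are assumptions for illustration:

```python
from data_designer.config.columns import DataDesignerColumnType, LLMTextColumnConfig
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.models import InferenceParameters, ModelConfig

builder = DataDesignerConfigBuilder(
    model_configs=[
        ModelConfig(alias="default", model="example/model-id", inference_parameters=InferenceParameters()),
    ]
)

# Style 1: pass a fully constructed column config object.
builder.add_column(
    LLMTextColumnConfig(
        name="topic",
        prompt="Propose a topic for a short article.",
        model_alias="default",
    )
)

# Style 2: pass name/column_type plus kwargs; the builder constructs the config.
builder.add_column(
    name="summary",
    column_type=DataDesignerColumnType.LLM_TEXT,
    prompt="Write a two-sentence summary of {{ topic }}.",
    model_alias="default",
)
```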
Args: @@ -128,11 +131,12 @@ def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None): """ self._column_configs = {} self._model_configs = load_model_configs(model_configs) - self._seed_config: SeedConfig | None = None + self._processor_configs: list[ProcessorConfig] = [] + self._seed_config: Optional[SeedConfig] = None self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] self._info = DataDesignerInfo() - self._datastore_settings: DatastoreSettings | None = None + self._datastore_settings: Optional[DatastoreSettings] = None @property def model_configs(self) -> list[ModelConfig]: @@ -193,10 +197,10 @@ def delete_model_config(self, alias: str) -> Self: def add_column( self, - column_config: ColumnConfigT | None = None, + column_config: Optional[ColumnConfigT] = None, *, - name: str | None = None, - column_type: DataDesignerColumnType | None = None, + name: Optional[str] = None, + column_type: Optional[DataDesignerColumnType] = None, **kwargs, ) -> Self: """Add a Data Designer column configuration to the current Data Designer configuration. @@ -233,9 +237,9 @@ def add_column( def add_constraint( self, - constraint: ColumnConstraintT | None = None, + constraint: Optional[ColumnConstraintT] = None, *, - constraint_type: ConstraintType | None = None, + constraint_type: Optional[ConstraintType] = None, **kwargs, ) -> Self: """Add a constraint to the current Data Designer configuration. @@ -283,6 +287,43 @@ def add_constraint( self._constraints.append(constraint) return self + def add_processor( + self, + processor_config: Optional[ProcessorConfig] = None, + *, + processor_type: Optional[ProcessorType] = None, + **kwargs, + ) -> Self: + """Add a processor to the current Data Designer configuration. + + You can either provide a processor config object directly, or provide a processor type and + additional keyword arguments to construct the processor config object. + + Args: + processor_config: The processor configuration object to add. + processor_type: The type of processor to add. + **kwargs: Additional keyword arguments to pass to the processor constructor. + + Returns: + The current Data Designer config builder instance. + """ + if processor_config is None: + if processor_type is None: + raise BuilderConfigurationError( + "🛑 You must provide either a 'processor_config' object or 'processor_type' " + "with additional keyword arguments." + ) + processor_config = get_processor_config_from_kwargs(processor_type=processor_type, **kwargs) + + # Checks elsewhere fail if DropColumnsProcessor drops a column but it is not marked for drop + if processor_config.processor_type == ProcessorType.DROP_COLUMNS: + for column in processor_config.column_names: + if column in self._column_configs: + self._column_configs[column].drop = True + + self._processor_configs.append(processor_config) + return self + def add_profiler(self, profiler_config: ColumnProfilerConfigT) -> Self: """Add a profiler to the current Data Designer configuration. 
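`add_processor` (added above) follows the same pattern: pass a processor config object, or a `processor_type` plus keyword arguments. A sketch continuing the hypothetical `builder` from the previous example; note that the processor config validator currently accepts only the post-batch stage, and `add_processor` also flips the `drop` flag on matching column configs:

```python
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.processors import DropColumnsProcessorConfig, ProcessorType

# Style 1: construct the processor config directly.
builder.add_processor(
    DropColumnsProcessorConfig(
        build_stage=BuildStage.POST_BATCH,  # the only stage in SUPPORTED_STAGES today
        column_names=["topic"],
    )
)

# Style 2 (equivalent): let the builder construct the config from a type plus kwargs.
builder.add_processor(
    processor_type=ProcessorType.DROP_COLUMNS,
    build_stage=BuildStage.POST_BATCH,
    column_names=["topic"],
)

# Dropping a column that was never defined is reported as an error-level violation
# by validate()/build(); see the validate_drop_columns_processor change later in this diff.
builder.validate()
```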
@@ -331,6 +372,7 @@ def build(self, *, skip_validation: bool = False, raise_exceptions: bool = False columns=list(self._column_configs.values()), constraints=self._constraints or None, profilers=self._profilers or None, + processors=self._processor_configs or None, ) def delete_constraints(self, target_column: str) -> Self: @@ -427,7 +469,15 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis column_type = resolve_string_enum(column_type, DataDesignerColumnType) return [c for c in self._column_configs.values() if c.column_type != column_type] - def get_seed_config(self) -> SeedConfig | None: + def get_processor_configs(self) -> list[ProcessorConfig]: + """Get processor configuration objects. + + Returns: + A list of the processor configuration objects that have been added to the builder. + """ + return self._processor_configs + + def get_seed_config(self) -> Optional[SeedConfig]: """Get the seed config for the current Data Designer configuration. Returns: """ return self._seed_config - def get_seed_datastore_settings(self) -> DatastoreSettings | None: + def get_seed_datastore_settings(self) -> Optional[DatastoreSettings]: """Get most recent datastore settings for the current Data Designer configuration. Returns: @@ -454,7 +504,7 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """ return len(self.get_columns_of_type(column_type)) - def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self: + def set_seed_datastore_settings(self, datastore_settings: Optional[DatastoreSettings]) -> Self: """Set the datastore settings for the seed dataset. Args: @@ -477,7 +527,9 @@ def validate(self, *, raise_exceptions: bool = False) -> Self: """ violations = validate_data_designer_config( - columns=list(self._column_configs.values()), allowed_references=self.allowed_references + columns=list(self._column_configs.values()), + processor_configs=self._processor_configs, + allowed_references=self.allowed_references, ) rich_print_violations(violations) if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0: @@ -516,7 +568,7 @@ def with_seed_dataset( self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name) return self - def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None: + def write_config(self, path: Union[str, Path], indent: Optional[int] = 2, **kwargs) -> None: """Write the current configuration to a file.
Args: diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index 8579ecd9..ba717fd8 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -3,12 +3,15 @@ from __future__ import annotations +from typing import Optional + from pydantic import Field from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase from .columns import ColumnConfigT from .models import ModelConfig +from .processors import ProcessorConfig from .sampler_constraints import ColumnConstraintT from .seed import SeedConfig @@ -30,7 +33,8 @@ class DataDesignerConfig(ExportableConfigBase): """ columns: list[ColumnConfigT] = Field(min_length=1) - model_configs: list[ModelConfig] | None = None - seed_config: SeedConfig | None = None - constraints: list[ColumnConstraintT] | None = None - profilers: list[ColumnProfilerConfigT] | None = None + model_configs: Optional[list[ModelConfig]] = None + seed_config: Optional[SeedConfig] = None + constraints: Optional[list[ColumnConstraintT]] = None + profilers: Optional[list[ColumnProfilerConfigT]] = None + processors: Optional[list[ProcessorConfig]] = None diff --git a/src/data_designer/config/dataset_builders.py b/src/data_designer/config/dataset_builders.py new file mode 100644 index 00000000..aa18df7f --- /dev/null +++ b/src/data_designer/config/dataset_builders.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + + +class BuildStage(str, Enum): + PRE_BATCH = "pre_batch" + POST_BATCH = "post_batch" + PRE_GENERATION = "pre_generation" + POST_GENERATION = "post_generation" diff --git a/src/data_designer/config/datastore.py b/src/data_designer/config/datastore.py index fbf17eed..d1ea0c07 100644 --- a/src/data_designer/config/datastore.py +++ b/src/data_designer/config/datastore.py @@ -5,7 +5,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union from huggingface_hub import HfApi, HfFileSystem import pandas as pd @@ -28,17 +28,18 @@ class DatastoreSettings(BaseModel): ..., description="Datastore endpoint. 
Use 'https://huggingface.co' for the Hugging Face Hub.", ) - token: str | None = Field(default=None, description="If needed, token to use for authentication.") + token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.") -def get_file_column_names(file_path: str | Path, file_type: str) -> list[str]: +def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]: """Extract column names based on file type.""" if file_type == "parquet": try: schema = pq.read_schema(file_path) if hasattr(schema, "names"): return schema.names - return [field.name for field in schema] + else: + return [field.name for field in schema] except Exception as e: logger.warning(f"Failed to process parquet file {file_path}: {e}") return [] @@ -70,13 +71,16 @@ def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | No raise InvalidConfigError("🛑 Datastore settings are required in order to upload datasets to the datastore.") if isinstance(datastore_settings, DatastoreSettings): return datastore_settings - if isinstance(datastore_settings, dict): + elif isinstance(datastore_settings, dict): return DatastoreSettings.model_validate(datastore_settings) - raise InvalidConfigError("🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary.") + else: + raise InvalidConfigError( + "🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary." + ) def upload_to_hf_hub( - dataset_path: str | Path, + dataset_path: Union[str, Path], filename: str, repo_id: str, datastore_settings: DatastoreSettings, @@ -105,7 +109,7 @@ def upload_to_hf_hub( def _fetch_seed_dataset_column_names_from_datastore( repo_id: str, filename: str, - datastore_settings: DatastoreSettings | dict | None = None, + datastore_settings: Optional[Union[DatastoreSettings, dict]] = None, ) -> list[str]: file_type = filename.split(".")[-1] if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS: @@ -123,7 +127,7 @@ def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) - return get_file_column_names(dataset_path, dataset_path.suffix.lower()[1:]) -def _validate_dataset_path(dataset_path: str | Path) -> Path: +def _validate_dataset_path(dataset_path: Union[str, Path]) -> Path: if not Path(dataset_path).is_file(): raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.") if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)): diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index b6bf26d6..4fedb0c9 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -4,11 +4,11 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Generic, TypeAlias, TypeVar +from typing import Any, Generic, List, Optional, TypeVar, Union import numpy as np from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .base import ConfigBase from .errors import InvalidConfigError @@ -49,7 +49,7 @@ def get_context(self, record: dict) -> dict[str, Any]: ... class ImageContext(ModalityContext): modality: Modality = Modality.IMAGE - image_format: ImageFormat | None = None + image_format: Optional[ImageFormat] = None def get_context(self, record: dict) -> dict[str, Any]: context = dict(type="image_url") @@ -82,8 +82,8 @@ def sample(self) -> float: ... 
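`resolve_datastore_settings` (above) normalizes either a `DatastoreSettings` instance or a plain dict, and raises `InvalidConfigError` for anything else. A small sketch of that contract; the endpoint value mirrors the field's own docstring:

```python
from data_designer.config.datastore import DatastoreSettings, resolve_datastore_settings

# An instance passes through unchanged.
settings = DatastoreSettings(endpoint="https://huggingface.co")
assert resolve_datastore_settings(settings) is settings

# A dict is validated into a DatastoreSettings object.
from_dict = resolve_datastore_settings({"endpoint": "https://huggingface.co", "token": None})
assert isinstance(from_dict, DatastoreSettings)
```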
class ManualDistributionParams(ConfigBase): - values: list[float] = Field(min_length=1) - weights: list[float] | None = None + values: List[float] = Field(min_length=1) + weights: Optional[List[float]] = None @model_validator(mode="after") def _normalize_weights(self) -> Self: @@ -99,7 +99,7 @@ def _validate_equal_lengths(self) -> Self: class ManualDistribution(Distribution[ManualDistributionParams]): - distribution_type: DistributionType | None = "manual" + distribution_type: Optional[DistributionType] = "manual" params: ManualDistributionParams def sample(self) -> float: @@ -118,26 +118,26 @@ def _validate_low_lt_high(self) -> Self: class UniformDistribution(Distribution[UniformDistributionParams]): - distribution_type: DistributionType | None = "uniform" + distribution_type: Optional[DistributionType] = "uniform" params: UniformDistributionParams def sample(self) -> float: return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0]) -DistributionT: TypeAlias = UniformDistribution | ManualDistribution +DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution] class InferenceParameters(ConfigBase): - temperature: float | DistributionT | None = None - top_p: float | DistributionT | None = None - max_tokens: int | None = Field(default=None, ge=1) + temperature: Optional[Union[float, DistributionT]] = None + top_p: Optional[Union[float, DistributionT]] = None + max_tokens: Optional[int] = Field(default=None, ge=1) max_parallel_requests: int = Field(default=4, ge=1) - timeout: int | None = Field(default=None, ge=1) - extra_body: dict[str, Any] | None = None + timeout: Optional[int] = Field(default=None, ge=1) + extra_body: Optional[dict[str, Any]] = None @property - def generate_kwargs(self) -> dict[str, float | int]: + def generate_kwargs(self) -> dict[str, Union[float, int]]: result = {} if self.temperature is not None: result["temperature"] = ( @@ -173,7 +173,7 @@ def _validate_top_p(self) -> Self: def _run_validation( self, - value: float | DistributionT | None, + value: Union[float, DistributionT, None], param_name: str, min_value: float, max_value: float, @@ -201,10 +201,10 @@ class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParameters - provider: str | None = None + provider: Optional[str] = None -def load_model_configs(model_configs: list[ModelConfig] | str | Path | None) -> list[ModelConfig]: +def load_model_configs(model_configs: Union[list[ModelConfig], str, Path, None]) -> list[ModelConfig]: if model_configs is None: return [] if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs): diff --git a/src/data_designer/config/preview_results.py b/src/data_designer/config/preview_results.py index ffe7ba20..41313c52 100644 --- a/src/data_designer/config/preview_results.py +++ b/src/data_designer/config/preview_results.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import Optional + import pandas as pd from .analysis.dataset_profiler import DatasetProfilerResults @@ -15,8 +17,8 @@ def __init__( self, *, config_builder: DataDesignerConfigBuilder, - dataset: pd.DataFrame | None = None, - analysis: DatasetProfilerResults | None = None, + dataset: Optional[pd.DataFrame] = None, + analysis: Optional[DatasetProfilerResults] = None, ): """Creates a new instance with results from a Data Designer preview run. 
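Since `temperature` and `top_p` may now be plain floats or `Distribution` objects, `generate_kwargs` can resolve them to freshly sampled values. A sketch using the uniform and manual distributions defined above; the specific bounds, values, and weights are illustrative:

```python
from data_designer.config.models import (
    InferenceParameters,
    ManualDistribution,
    ManualDistributionParams,
    UniformDistribution,
    UniformDistributionParams,
)

inference_parameters = InferenceParameters(
    # Temperature drawn uniformly from [0.5, 0.9] each time kwargs are generated.
    temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=0.9)),
    # top_p picked from a small weighted set of values.
    top_p=ManualDistribution(params=ManualDistributionParams(values=[0.9, 0.95, 1.0], weights=[1.0, 3.0, 1.0])),
    max_tokens=512,
)

# Distribution-valued fields are resolved to concrete numbers, so repeated
# accesses can yield different sampled values.
print(inference_parameters.generate_kwargs)
print(inference_parameters.generate_kwargs)
```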
diff --git a/src/data_designer/config/processors.py b/src/data_designer/config/processors.py new file mode 100644 index 00000000..163956a3 --- /dev/null +++ b/src/data_designer/config/processors.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from enum import Enum +from typing import Literal + +from pydantic import Field, field_validator + +from .base import ConfigBase +from .dataset_builders import BuildStage + +SUPPORTED_STAGES = [BuildStage.POST_BATCH] + + +class ProcessorType(str, Enum): + DROP_COLUMNS = "drop_columns" + + +class ProcessorConfig(ConfigBase, ABC): + build_stage: BuildStage = Field( + ..., description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}" + ) + + @field_validator("build_stage") + def validate_build_stage(cls, v: BuildStage) -> BuildStage: + if v not in SUPPORTED_STAGES: + raise ValueError( + f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}" + ) + return v + + +def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig: + if processor_type == ProcessorType.DROP_COLUMNS: + return DropColumnsProcessorConfig(**kwargs) + + +class DropColumnsProcessorConfig(ProcessorConfig): + column_names: list[str] + processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS diff --git a/src/data_designer/config/sampler_constraints.py b/src/data_designer/config/sampler_constraints.py index d5036616..9dea2bae 100644 --- a/src/data_designer/config/sampler_constraints.py +++ b/src/data_designer/config/sampler_constraints.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import TypeAlias +from typing import Union + +from typing_extensions import TypeAlias from .base import ConfigBase @@ -46,4 +48,4 @@ def constraint_type(self) -> ConstraintType: return ConstraintType.COLUMN_INEQUALITY -ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint +ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint] diff --git a/src/data_designer/config/sampler_params.py b/src/data_designer/config/sampler_params.py index 6f3abf9e..03b278e2 100644 --- a/src/data_designer/config/sampler_params.py +++ b/src/data_designer/config/sampler_params.py @@ -2,11 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Literal, TypeAlias +from typing import Literal, Optional, Union import pandas as pd from pydantic import Field, field_validator, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .base import ConfigBase from .utils.constants import ( @@ -41,12 +41,12 @@ class SamplerType(str, Enum): class CategorySamplerParams(ConfigBase): - values: list[str | int | float] = Field( + values: list[Union[str, int, float]] = Field( ..., min_length=1, description="List of possible categorical values that can be sampled from.", ) - weights: list[float] | None = Field( + weights: Optional[list[float]] = Field( default=None, description=( "List of unnormalized probability weights to assigned to each value, in order. 
" @@ -87,7 +87,7 @@ def _validate_param_is_datetime(cls, value: str) -> str: class SubcategorySamplerParams(ConfigBase): category: str = Field(..., description="Name of parent category to this subcategory.") - values: dict[str, list[str | int | float]] = Field( + values: dict[str, list[Union[str, int, float]]] = Field( ..., description="Mapping from each value of parent category to a list of subcategory values.", ) @@ -127,7 +127,7 @@ def _validate_min_less_than_max(self) -> Self: class UUIDSamplerParams(ConfigBase): - prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.") + prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.") short_form: bool = Field( default=False, description="If true, all UUIDs sampled will be truncated at 8 characters.", @@ -153,7 +153,7 @@ class ScipySamplerParams(ConfigBase): ..., description="Parameters of the scipy.stats distribution given in `dist_name`.", ) - decimal_places: int | None = Field( + decimal_places: Optional[int] = Field( default=None, description="Number of decimal places to round the sampled values to." ) @@ -191,7 +191,7 @@ class BernoulliMixtureSamplerParams(ConfigBase): class GaussianSamplerParams(ConfigBase): mean: float = Field(..., description="Mean of the Gaussian distribution") stddev: float = Field(..., description="Standard deviation of the Gaussian distribution") - decimal_places: int | None = Field( + decimal_places: Optional[int] = Field( default=None, description="Number of decimal places to round the sampled values to." ) @@ -203,7 +203,7 @@ class PoissonSamplerParams(ConfigBase): class UniformSamplerParams(ConfigBase): low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.") high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.") - decimal_places: int | None = Field( + decimal_places: Optional[int] = Field( default=None, description="Number of decimal places to round the sampled values to." ) @@ -223,11 +223,11 @@ class PersonSamplerParams(ConfigBase): "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..." ), ) - sex: SexT | None = Field( + sex: Optional[SexT] = Field( default=None, description="If specified, then only synthetic people of the specified sex will be sampled.", ) - city: str | list[str] | None = Field( + city: Optional[Union[str, list[str]]] = Field( default=None, description="If specified, then only synthetic people from these cities will be sampled.", ) @@ -238,7 +238,7 @@ class PersonSamplerParams(ConfigBase): max_length=2, ) - state: str | list[str] | None = Field( + state: Optional[Union[str, list[str]]] = Field( default=None, description=( "Only supported for 'en_US' locale. 
If specified, then only synthetic people " @@ -265,7 +265,8 @@ def generator_kwargs(self) -> list[str]: def people_gen_key(self) -> str: if self.locale in LOCALES_WITH_MANAGED_DATASETS and self.sample_dataset_when_available: return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale - return f"{self.locale}_faker" + else: + return f"{self.locale}_faker" @field_validator("age_range") @classmethod @@ -321,21 +322,21 @@ def _validate_with_synthetic_personas(self) -> Self: return self -SamplerParamsT: TypeAlias = ( - SubcategorySamplerParams - | CategorySamplerParams - | DatetimeSamplerParams - | PersonSamplerParams - | TimeDeltaSamplerParams - | UUIDSamplerParams - | BernoulliSamplerParams - | BernoulliMixtureSamplerParams - | BinomialSamplerParams - | GaussianSamplerParams - | PoissonSamplerParams - | UniformSamplerParams - | ScipySamplerParams -) +SamplerParamsT: TypeAlias = Union[ + SubcategorySamplerParams, + CategorySamplerParams, + DatetimeSamplerParams, + PersonSamplerParams, + TimeDeltaSamplerParams, + UUIDSamplerParams, + BernoulliSamplerParams, + BernoulliMixtureSamplerParams, + BinomialSamplerParams, + GaussianSamplerParams, + PoissonSamplerParams, + UniformSamplerParams, + ScipySamplerParams, +] def is_numerical_sampler_type(sampler_type: SamplerType) -> bool: diff --git a/src/data_designer/config/utils/__init__.py b/src/data_designer/config/utils/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/config/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/config/utils/code_lang.py b/src/data_designer/config/utils/code_lang.py index 4f1af4c8..c0621d36 100644 --- a/src/data_designer/config/utils/code_lang.py +++ b/src/data_designer/config/utils/code_lang.py @@ -4,6 +4,7 @@ from __future__ import annotations from enum import Enum +from typing import Union class CodeLang(str, Enum): @@ -25,17 +26,17 @@ class CodeLang(str, Enum): SQL_ANSI = "sql:ansi" @staticmethod - def parse(value: str | CodeLang) -> tuple[str, str | None]: + def parse(value: Union[str, CodeLang]) -> tuple[str, Union[str, None]]: value = value.value if isinstance(value, CodeLang) else value split_vals = value.split(":") return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None) @staticmethod - def parse_lang(value: str | CodeLang) -> str: + def parse_lang(value: Union[str, CodeLang]) -> str: return CodeLang.parse(value)[0] @staticmethod - def parse_dialect(value: str | CodeLang) -> str | None: + def parse_dialect(value: Union[str, CodeLang]) -> Union[str, None]: return CodeLang.parse(value)[1] @staticmethod @@ -57,7 +58,7 @@ def supported_values() -> set[str]: ########################################################## -def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str: +def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str: """Convert the code language to a syntax lexer for Pygments. 
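`CodeLang.parse` (above) splits a `language[:dialect]` value into its parts, returning `None` when no dialect is present. A quick sketch using the `SQL_ANSI` member shown in the enum:

```python
from data_designer.config.utils.code_lang import CodeLang

# Members carry "language[:dialect]" values, e.g. CodeLang.SQL_ANSI == "sql:ansi".
assert CodeLang.parse(CodeLang.SQL_ANSI) == ("sql", "ansi")
assert CodeLang.parse_lang("sql:ansi") == "sql"
assert CodeLang.parse_dialect("sql:ansi") == "ansi"

# Values without a dialect component yield None for the dialect.
assert CodeLang.parse("python") == ("python", None)
```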
Reference: https://pygments.org/docs/lexers/ diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py index 59958991..7dd6f41a 100644 --- a/src/data_designer/config/utils/constants.py +++ b/src/data_designer/config/utils/constants.py @@ -11,7 +11,7 @@ DEFAULT_REPR_HTML_STYLE = "nord" REPR_HTML_FIXED_WIDTH = 1000 -REPR_HTML_TEMPLATE = f""" +REPR_HTML_TEMPLATE = """ {{highlighted_html}} -""" +""".format(fixed_width=REPR_HTML_FIXED_WIDTH) class NordColor(Enum): diff --git a/src/data_designer/config/utils/io_helpers.py b/src/data_designer/config/utils/io_helpers.py index 96b7a9ff..b47d35c9 100644 --- a/src/data_designer/config/utils/io_helpers.py +++ b/src/data_designer/config/utils/io_helpers.py @@ -8,7 +8,7 @@ from numbers import Number import os from pathlib import Path -from typing import Any +from typing import Any, Union import numpy as np import pandas as pd @@ -39,7 +39,8 @@ def read_parquet_dataset(path: Path) -> pd.DataFrame: [pd.read_parquet(file, dtype_backend="pyarrow") for file in sorted(path.glob("*.parquet"))], ignore_index=True, ) - raise e + else: + raise e def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None: @@ -62,7 +63,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None: dataframe.to_json(file_path, orient="records", lines=True) -def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path: +def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path: """Validate that a dataset file path has a valid extension and optionally exists. Args: @@ -82,7 +83,7 @@ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) return file_path -def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame: +def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame: """Load a dataframe from file if a path is given, otherwise return the dataframe. Args: @@ -106,14 +107,15 @@ def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame: # Load the dataframe based on the file extension. if ext == "csv": return pd.read_csv(dataframe) - if ext == "json": + elif ext == "json": return pd.read_json(dataframe, lines=True) - if ext == "parquet": + elif ext == "parquet": return pd.read_parquet(dataframe) - raise ValueError(f"Unsupported file format: {dataframe}") + else: + raise ValueError(f"Unsupported file format: {dataframe}") -def smart_load_yaml(yaml_in: str | Path | dict) -> dict: +def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict: """Return the yaml config as a dict given flexible input types. Args: @@ -130,7 +132,8 @@ def smart_load_yaml(yaml_in: str | Path | dict) -> dict: elif isinstance(yaml_in, str): if yaml_in.endswith((".yaml", ".yml")) and not os.path.isfile(yaml_in): raise FileNotFoundError(f"File not found: {yaml_in}") - yaml_out = yaml.safe_load(yaml_in) + else: + yaml_out = yaml.safe_load(yaml_in) else: raise ValueError( f"'{yaml_in}' is an invalid yaml config format. Valid options are: dict, yaml string, or yaml file path." 
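`smart_load_yaml` (above) accepts a dict, a YAML string, or a YAML file path. A sketch of that contract; the assumption that dict inputs pass through unchanged follows from the function's accepted-input list:

```python
from data_designer.config.utils.io_helpers import smart_load_yaml

# A dict input is returned as-is (assumed from the accepted-input list).
assert smart_load_yaml({"alias": "default"}) == {"alias": "default"}

# YAML text is parsed with yaml.safe_load.
assert smart_load_yaml("alias: default\nmodel: example/model-id") == {
    "alias": "default",
    "model": "example/model-id",
}

# Strings ending in .yaml/.yml are treated as file paths and must exist.
try:
    smart_load_yaml("missing_config.yaml")
except FileNotFoundError as err:
    print(err)
```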
@@ -142,14 +145,17 @@ def smart_load_yaml(yaml_in: str | Path | dict) -> dict: return yaml_out -def serialize_data(data: dict | list | str | Number, **kwargs) -> str: - if isinstance(data, dict) or isinstance(data, list): +def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str: + if isinstance(data, dict): + return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs) + elif isinstance(data, list): return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs) - if isinstance(data, str): + elif isinstance(data, str): return data - if isinstance(data, Number): + elif isinstance(data, Number): return str(data) - raise ValueError(f"Invalid data type: {type(data)}") + else: + raise ValueError(f"Invalid data type: {type(data)}") def _convert_to_serializable(obj: Any) -> Any: diff --git a/src/data_designer/config/utils/misc.py b/src/data_designer/config/utils/misc.py index 05e5609a..c6b55d29 100644 --- a/src/data_designer/config/utils/misc.py +++ b/src/data_designer/config/utils/misc.py @@ -5,6 +5,7 @@ from contextlib import contextmanager import json +from typing import Optional, Union from jinja2 import TemplateSyntaxError, meta from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -57,7 +58,9 @@ def get_prompt_template_keywords(template: str) -> set[str]: return keywords -def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None: +def json_indent_list_of_strings( + column_names: list[str], *, indent: Optional[Union[int, str]] = None +) -> Optional[Union[list[str], str]]: """Convert a list of column names to a JSON string if the list is long. This function helps keep Data Designer's __repr__ output clean and readable. diff --git a/src/data_designer/config/utils/numerical_helpers.py b/src/data_designer/config/utils/numerical_helpers.py index c4f8a197..7d227cd3 100644 --- a/src/data_designer/config/utils/numerical_helpers.py +++ b/src/data_designer/config/utils/numerical_helpers.py @@ -3,7 +3,7 @@ import numbers from numbers import Number -from typing import Any +from typing import Any, Type from .constants import REPORTING_PRECISION @@ -18,7 +18,7 @@ def is_float(val: Any) -> bool: def prepare_number_for_reporting( value: Number, - target_type: type[Number], + target_type: Type[Number], precision: int = REPORTING_PRECISION, ) -> Number: """Ensure native python types and round to `precision` decimal digits.""" diff --git a/src/data_designer/config/utils/type_helpers.py b/src/data_designer/config/utils/type_helpers.py index 847e1f4c..02b17bb6 100644 --- a/src/data_designer/config/utils/type_helpers.py +++ b/src/data_designer/config/utils/type_helpers.py @@ -3,7 +3,7 @@ from enum import Enum import inspect -from typing import Any +from typing import Any, Type from pydantic import BaseModel @@ -11,7 +11,7 @@ from .errors import InvalidEnumValueError -def get_sampler_params() -> dict[str, type[BaseModel]]: +def get_sampler_params() -> dict[str, Type[BaseModel]]: """Returns a dictionary of sampler parameter classes.""" params_cls_list = [ params_cls @@ -38,7 +38,7 @@ def get_sampler_params() -> dict[str, type[BaseModel]]: return params_cls_dict -def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum: +def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum: if not issubclass(enum_type, Enum): raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. 
You provided: {enum_type}") invalid_enum_value_error = InvalidEnumValueError( @@ -47,7 +47,7 @@ def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum: ) if isinstance(enum_instance, enum_type): return enum_instance - if isinstance(enum_instance, str): + elif isinstance(enum_instance, str): try: return enum_type(enum_instance) except ValueError: diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index a3328637..a3864d31 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -5,6 +5,7 @@ from enum import Enum from string import Formatter +from typing import Optional from jinja2 import meta from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -15,6 +16,7 @@ from rich.panel import Panel from ..columns import ColumnConfigT, DataDesignerColumnType +from ..processors import ProcessorConfig, ProcessorType from ..validator_params import ValidatorType from .constants import RICH_CONSOLE_THEME from .misc import can_run_data_designer_locally @@ -28,6 +30,7 @@ class ViolationType(str, Enum): EXPRESSION_REFERENCE_MISSING = "expression_reference_missing" F_STRING_SYNTAX = "f_string_syntax" LOCAL_ONLY_COLUMN = "local_only_column" + INVALID_COLUMN = "invalid_column" INVALID_MODEL_CONFIG = "invalid_model_config" INVALID_REFERENCE = "invalid_reference" PROMPT_WITHOUT_REFERENCES = "prompt_without_references" @@ -39,7 +42,7 @@ class ViolationLevel(str, Enum): class Violation(BaseModel): - column: str | None = None + column: Optional[str] = None type: ViolationType message: str level: ViolationLevel @@ -51,6 +54,7 @@ def has_column(self) -> bool: def validate_data_designer_config( columns: list[ColumnConfigT], + processor_configs: list[ProcessorConfig], allowed_references: list[str], ) -> list[Violation]: violations = [] @@ -58,6 +62,7 @@ def validate_data_designer_config( violations.extend(validate_code_validation(columns=columns)) violations.extend(validate_expression_references(columns=columns, allowed_references=allowed_references)) violations.extend(validate_columns_not_all_dropped(columns=columns)) + violations.extend(validate_drop_columns_processor(columns=columns, processor_configs=processor_configs)) if not can_run_data_designer_locally(): violations.extend(validate_local_only_columns(columns=columns)) return violations @@ -147,7 +152,7 @@ def validate_prompt_templates( if ( prompt_type == "prompt" and len(prompt_references) == 0 - and (not hasattr(column, "multi_modal_context") or column.multi_modal_context is None) + and (not hasattr(column, "multi_modal_context") or getattr(column, "multi_modal_context") is None) ): message = ( f"The {prompt_type} template for '{column.name}' does not reference any columns. 
" @@ -262,6 +267,27 @@ def validate_columns_not_all_dropped( return [] +def validate_drop_columns_processor( + columns: list[ColumnConfigT], + processor_configs: list[ProcessorConfig], +) -> list[Violation]: + all_column_names = set([c.name for c in columns]) + for processor_config in processor_configs: + if processor_config.processor_type == ProcessorType.DROP_COLUMNS: + invalid_columns = set(processor_config.column_names) - all_column_names + if len(invalid_columns) > 0: + return [ + Violation( + column=c, + type=ViolationType.INVALID_COLUMN, + message=f"Drop columns processor is configured to drop column '{c!r}', but the column is not defined.", + level=ViolationLevel.ERROR, + ) + for c in invalid_columns + ] + return [] + + def validate_expression_references( columns: list[ColumnConfigT], allowed_references: list[str], diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index fbba948c..f245517c 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -7,7 +7,7 @@ from enum import Enum from functools import cached_property import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union import numpy as np import pandas as pd @@ -51,22 +51,23 @@ class WithRecordSamplerMixin: def _record_sampler_dataset(self) -> pd.DataFrame: if hasattr(self, "dataset") and self.dataset is not None and isinstance(self.dataset, pd.DataFrame): return self.dataset - if ( + elif ( hasattr(self, "load_dataset") and callable(self.load_dataset) and (dataset := self.load_dataset()) is not None and isinstance(dataset, pd.DataFrame) ): return dataset - raise DatasetSampleDisplayError("No valid dataset found in results object.") + else: + raise DatasetSampleDisplayError("No valid dataset found in results object.") def display_sample_record( self, - index: int | None = None, + index: Optional[int] = None, *, hide_seed_columns: bool = False, syntax_highlighting_theme: str = "dracula", - background_color: str | None = None, + background_color: Optional[str] = None, ) -> None: """Display a sample record from the Data Designer dataset preview. 
@@ -100,11 +101,11 @@ def display_sample_record( def create_rich_histogram_table( - data: dict[str, int | float], + data: dict[str, Union[int, float]], column_names: tuple[int, int], name_style: str = ColorPalette.BLUE.value, value_style: str = ColorPalette.TEAL.value, - title: str | None = None, + title: Optional[str] = None, **kwargs, ) -> Table: table = Table(title=title, **kwargs) @@ -120,11 +121,11 @@ def create_rich_histogram_table( def display_sample_record( - record: dict | pd.Series | pd.DataFrame, + record: Union[dict, pd.Series, pd.DataFrame], config_builder: DataDesignerConfigBuilder, - background_color: str | None = None, + background_color: Optional[str] = None, syntax_highlighting_theme: str = "dracula", - record_index: int | None = None, + record_index: Optional[int] = None, hide_seed_columns: bool = False, ): if isinstance(record, (dict, pd.Series)): @@ -227,7 +228,7 @@ def display_sample_record( def display_sampler_table( sampler_params: dict[SamplerType, ConfigBase], - title: str | None = None, + title: Optional[str] = None, ) -> None: table = Table(expand=True) table.add_column("Type") @@ -285,7 +286,7 @@ def _get_field_type(field: dict) -> str: return field["type"] # union type - if "anyOf" in field: + elif "anyOf" in field: types = [] for f in field["anyOf"]: if "$ref" in f: diff --git a/src/data_designer/config/validator_params.py b/src/data_designer/config/validator_params.py index c6b5e41e..0dedbda4 100644 --- a/src/data_designer/config/validator_params.py +++ b/src/data_designer/config/validator_params.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Any, TypeAlias +from typing import Any, Optional, Union from pydantic import Field, field_serializer, model_validator -from typing_extensions import Self +from typing_extensions import Self, TypeAlias from .base import ConfigBase from .utils.code_lang import SQL_DIALECTS, CodeLang @@ -35,7 +35,7 @@ class LocalCallableValidatorParams(ConfigBase): validation_function: Any = Field( description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data" ) - output_schema: dict[str, Any] | None = Field( + output_schema: Optional[dict[str, Any]] = Field( default=None, description="Expected schema for local callable validator's output" ) @@ -52,7 +52,7 @@ def validate_validation_function(self) -> Self: class RemoteValidatorParams(ConfigBase): endpoint_url: str = Field(description="URL of the remote endpoint") - output_schema: dict[str, Any] | None = Field( + output_schema: Optional[dict[str, Any]] = Field( default=None, description="Expected schema for remote validator's output" ) timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request") @@ -61,4 +61,8 @@ class RemoteValidatorParams(ConfigBase): max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make") -ValidatorParamsT: TypeAlias = CodeValidatorParams | LocalCallableValidatorParams | RemoteValidatorParams +ValidatorParamsT: TypeAlias = Union[ + CodeValidatorParams, + LocalCallableValidatorParams, + RemoteValidatorParams, +] diff --git a/src/data_designer/engine/analysis/__init__.py b/src/data_designer/engine/analysis/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/analysis/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/analysis/column_profilers/__init__.py b/src/data_designer/engine/analysis/column_profilers/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/analysis/column_profilers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py index 911000a0..37de5dbd 100644 --- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py +++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py @@ -5,6 +5,7 @@ import logging import random +from typing import Union from data_designer.config.analysis.column_profilers import ( JudgeScoreProfilerConfig, @@ -95,7 +96,7 @@ def _summarize_score_sample( name: str, sample: list[JudgeScoreSample], histogram: CategoricalHistogramData, - distribution: CategoricalDistribution | NumericalDistribution | MissingValue, + distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue], distribution_type: ColumnDistributionType, ) -> JudgeScoreSummary: if isinstance(distribution, MissingValue) or not sample: @@ -107,7 +108,7 @@ def _summarize_score_sample( category_info = [] total_count = sum(histogram.counts) - for cat, count in zip(histogram.categories, histogram.counts, strict=False): + for cat, count in zip(histogram.categories, histogram.counts): percentage = (count / total_count) * 100 category_info.append(f"{cat}: {count} records ({percentage:.1f}%)") diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py index c01a3737..dd4fb1e9 100644 --- a/src/data_designer/engine/analysis/column_statistics.py +++ b/src/data_designer/engine/analysis/column_statistics.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging -from typing import Any, TypeAlias +from typing import Any, Type, TypeAlias, Union import pandas as pd from pydantic import BaseModel @@ -49,7 +49,7 @@ def df(self) -> pd.DataFrame: return self.column_config_with_df.df @property - def column_statistics_type(self) -> type[ColumnStatisticsT]: + def column_statistics_type(self) -> Type[ColumnStatisticsT]: return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics) def calculate(self) -> Self: @@ -146,17 +146,17 @@ class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): . 
} -ColumnStatisticsCalculatorT: TypeAlias = ( - ExpressionColumnStatisticsCalculator - | ValidationColumnStatisticsCalculator - | GeneralColumnStatisticsCalculator - | LLMCodeColumnStatisticsCalculator - | LLMJudgedColumnStatisticsCalculator - | LLMStructuredColumnStatisticsCalculator - | LLMTextColumnStatisticsCalculator - | SamplerColumnStatisticsCalculator - | SeedDatasetColumnStatisticsCalculator -) +ColumnStatisticsCalculatorT: TypeAlias = Union[ + ExpressionColumnStatisticsCalculator, + ValidationColumnStatisticsCalculator, + GeneralColumnStatisticsCalculator, + LLMCodeColumnStatisticsCalculator, + LLMJudgedColumnStatisticsCalculator, + LLMStructuredColumnStatisticsCalculator, + LLMTextColumnStatisticsCalculator, + SamplerColumnStatisticsCalculator, + SeedDatasetColumnStatisticsCalculator, +] DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = { DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator, DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator, diff --git a/src/data_designer/engine/analysis/reporting/__init__.py b/src/data_designer/engine/analysis/reporting/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/analysis/reporting/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/analysis/utils/__init__.py b/src/data_designer/engine/analysis/utils/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/analysis/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/analysis/utils/judge_score_processing.py b/src/data_designer/engine/analysis/utils/judge_score_processing.py index 8190e9f6..95b01ccc 100644 --- a/src/data_designer/engine/analysis/utils/judge_score_processing.py +++ b/src/data_designer/engine/analysis/utils/judge_score_processing.py @@ -3,7 +3,7 @@ from collections import defaultdict import logging -from typing import Any +from typing import Any, Optional, Union import pandas as pd @@ -21,7 +21,7 @@ def extract_judge_score_distributions( column_config: LLMJudgeColumnConfig, df: pd.DataFrame -) -> JudgeScoreDistributions | MissingValue: +) -> Union[JudgeScoreDistributions, MissingValue]: scores = defaultdict(list) reasoning = defaultdict(list) @@ -79,10 +79,10 @@ def extract_judge_score_distributions( def sample_scores_and_reasoning( - scores: list[int | str], + scores: list[Union[int, str]], reasoning: list[str], num_samples: int, - random_seed: int | None = None, + random_seed: Optional[int] = None, ) -> list[JudgeScoreSample]: if len(scores) != len(reasoning): raise ValueError("scores and reasoning must have the same length") @@ -96,10 +96,7 @@ def sample_scores_and_reasoning( df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning}) if len(scores) <= num_samples: - return [ - JudgeScoreSample(score=score, reasoning=reasoning) - for score, reasoning in zip(scores, reasoning, strict=False) - ] + return [JudgeScoreSample(score=score, reasoning=reasoning) for score, reasoning in zip(scores, reasoning)] # Sample maintaining original proportions from each category (int or str) # Calculate the frequency of each score category diff --git a/src/data_designer/engine/column_generators/generators/expression.py 
b/src/data_designer/engine/column_generators/generators/expression.py index 2f465918..7da80e66 100644 --- a/src/data_designer/engine/column_generators/generators/expression.py +++ b/src/data_designer/engine/column_generators/generators/expression.py @@ -50,11 +50,11 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: def _cast_type(self, value: str) -> str | float | int | bool: if self.config.dtype == "str": return value - if self.config.dtype == "float": + elif self.config.dtype == "float": return float(value) - if self.config.dtype == "int": + elif self.config.dtype == "int": return int(float(value)) - if self.config.dtype == "bool": + elif self.config.dtype == "bool": try: return bool(int(float(value))) except ValueError: diff --git a/src/data_designer/engine/column_generators/generators/samplers.py b/src/data_designer/engine/column_generators/generators/samplers.py index 3f274292..d28e0a61 100644 --- a/src/data_designer/engine/column_generators/generators/samplers.py +++ b/src/data_designer/engine/column_generators/generators/samplers.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable from functools import partial import logging import random +from typing import Callable import pandas as pd diff --git a/src/data_designer/engine/column_generators/generators/validation.py b/src/data_designer/engine/column_generators/generators/validation.py index 504a2ab6..424165fc 100644 --- a/src/data_designer/engine/column_generators/generators/validation.py +++ b/src/data_designer/engine/column_generators/generators/validation.py @@ -35,7 +35,7 @@ def get_validator_from_params(validator_type: ValidatorType, validator_params: V if validator_type == ValidatorType.CODE: if validator_params.code_lang == CodeLang.PYTHON: return PythonValidator(validator_params) - if validator_params.code_lang in SQL_DIALECTS: + elif validator_params.code_lang in SQL_DIALECTS: return SQLValidator(validator_params) elif validator_type == ValidatorType.REMOTE: return RemoteValidator(validator_params) diff --git a/src/data_designer/engine/column_generators/utils/__init__.py b/src/data_designer/engine/column_generators/utils/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/column_generators/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/src/data_designer/engine/column_generators/utils/judge_score_factory.py index 40212dbd..b4d458c6 100644 --- a/src/data_designer/engine/column_generators/utils/judge_score_factory.py +++ b/src/data_designer/engine/column_generators/utils/judge_score_factory.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum +from typing import Type from pydantic import BaseModel, ConfigDict, Field, create_model @@ -18,7 +19,7 @@ class BaseJudgeResponse(BaseModel): reasoning: str = Field(..., description="Reasoning for the assigned score.") -def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str: +def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str: """Convert score descriptions into a single text block.""" list_block = "\n".join( [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()] @@ -26,7 +27,7 @@ def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str: return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block) -def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]: +def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]: """Create a JudgeResponse data type.""" enum_members = {} for option in score.options.keys(): @@ -45,8 +46,8 @@ def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]: def create_judge_structured_output_model( - judge_responses: list[type[BaseJudgeResponse]], -) -> type[BaseModel]: + judge_responses: list[Type[BaseJudgeResponse]], +) -> Type[BaseModel]: """Create a JudgeStructuredOutput class dynamically.""" return create_model( "JudgeStructuredOutput", diff --git a/src/data_designer/engine/configurable_task.py b/src/data_designer/engine/configurable_task.py index 94d8f799..0c3f10a5 100644 --- a/src/data_designer/engine/configurable_task.py +++ b/src/data_designer/engine/configurable_task.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, TypeVar +from typing import Generic, Type, TypeVar import pandas as pd @@ -30,7 +30,7 @@ def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | self._initialize() @classmethod - def get_config_type(cls) -> type[TaskConfigT]: + def get_config_type(cls) -> Type[TaskConfigT]: for base in cls.__orig_bases__: if hasattr(base, "__args__") and len(base.__args__) == 1 and issubclass(base.__args__[0], ConfigBase): return base.__args__[0] diff --git a/src/data_designer/engine/dataset_builders/__init__.py b/src/data_designer/engine/dataset_builders/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/dataset_builders/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/dataset_builders/artifact_storage.py b/src/data_designer/engine/dataset_builders/artifact_storage.py index c29f4591..4a1e4919 100644 --- a/src/data_designer/engine/dataset_builders/artifact_storage.py +++ b/src/data_designer/engine/dataset_builders/artifact_storage.py @@ -6,6 +6,7 @@ import logging from pathlib import Path import shutil +from typing import Union import pandas as pd from pydantic import BaseModel, field_validator, model_validator @@ -57,7 +58,7 @@ def partial_results_path(self) -> Path: return self.base_dataset_path / self.partial_results_folder_name @field_validator("artifact_path") - def validate_artifact_path(cls, v: Path | str) -> Path: + def validate_artifact_path(cls, v: Union[Path, str]) -> Path: v = Path(v) if not v.is_dir(): raise ArtifactStorageError("Artifact path must exist and be a directory") diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index 430a56ea..f16a5a96 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -1,20 +1,26 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable import functools import json import logging from pathlib import Path import time +from typing import Callable import pandas as pd from data_designer.config.columns import ColumnConfigT +from data_designer.config.dataset_builders import BuildStage +from data_designer.config.processors import ( + DropColumnsProcessorConfig, + ProcessorConfig, + ProcessorType, +) from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration -from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage -from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage +from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( DatasetBuilderColumnConfigT, MultiColumnConfig, @@ -26,7 +32,7 @@ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import ( DatasetBatchManager, ) -from data_designer.engine.processing.processors.configs import DropColumnsProcessorConfig +from data_designer.engine.processing.processors.base import Processor from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider @@ -38,6 +44,7 @@ class ColumnWiseDatasetBuilder: def __init__( self, column_configs: list[DatasetBuilderColumnConfigT], + processor_configs: list[ProcessorConfig], resource_provider: ResourceProvider, registry: DataDesignerRegistry | None = None, ): @@ -46,6 +53,7 @@ def __init__( self._records_to_drop: set[int] = set() self._registry = registry or DataDesignerRegistry() self._column_configs = column_configs + self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs) 
self._validate_column_configs() @property @@ -79,8 +87,12 @@ def build( for batch_idx in range(1, self.batch_manager.num_batches + 1): logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}") self._run_batch(generators) - df_batch = self.batch_manager.get_current_batch(as_dataframe=True) - self._write_processed_batch(self.drop_columns_if_needed(df_batch, save_dropped_columns=True)) + df_batch = self._run_processors( + stage=BuildStage.POST_BATCH, + dataframe=self.batch_manager.get_current_batch(as_dataframe=True), + current_batch_number=batch_idx, + ) + self._write_processed_batch(df_batch) self.batch_manager.finish_batch(on_batch_complete) self.batch_manager.finish() @@ -99,7 +111,11 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: start_time = time.perf_counter() self.batch_manager.start(num_records=num_records, buffer_size=num_records) self._run_batch(generators, save_partial_results=False) - dataset = self.batch_manager.get_current_batch(as_dataframe=True) + dataset = self._run_processors( + stage=BuildStage.POST_BATCH, + dataframe=self.batch_manager.get_current_batch(as_dataframe=True), + current_batch_number=None, # preview mode does not have a batch number + ) self.batch_manager.reset() model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats( @@ -109,29 +125,6 @@ return dataset - def drop_columns_if_needed(self, dataframe: pd.DataFrame, *, save_dropped_columns: bool = False) -> pd.DataFrame: - if len(columns_to_drop := [config.name for config in self.single_column_configs if config.drop]) == 0: - return dataframe - try: - dropped_column_parquet_file_name = ( - None - if not save_dropped_columns - else self.artifact_storage.create_batch_file_path( - batch_number=self.batch_manager.get_current_batch_number(), - batch_stage=BatchStage.DROPPED_COLUMNS, - ).name - ) - df = DropColumnsProcessor( - config=DropColumnsProcessorConfig( - column_names=columns_to_drop, - dropped_column_parquet_file_name=dropped_column_parquet_file_name, - ), - resource_provider=self._resource_provider, - ).process(dataframe) - return df - except Exception as e: - raise DatasetGenerationError(f"🛑 Failed to drop columns {columns_to_drop}: {e}") - def _initialize_generators(self) -> list[ColumnGenerator]: return [ self._registry.column_generators.get_for_config_type(type(config))( @@ -218,6 +211,51 @@ def _validate_column_configs(self) -> None: ).can_generate_from_scratch: raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.") + def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]: + # Collect the columns whose configs mark them for drop + columns_to_drop = [config.name for config in self.single_column_configs if config.drop] + + processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage} + for config in processor_configs: + processors[config.build_stage].append( + self._registry.processors.get_for_config_type(type(config))( + config=config, + resource_provider=self._resource_provider, + ) + ) + + # A manually included "drop columns" processor takes precedence (it can, e.g., target stages other than post-batch) + if config.processor_type == ProcessorType.DROP_COLUMNS: + for column in config.column_names: + if column in columns_to_drop: + columns_to_drop.remove(column) + + # If there are still columns marked for drop, add the "drop columns" processor to drop them + if len(columns_to_drop) > 0:
+ processors[BuildStage.POST_BATCH].append( # as post-batch by default + DropColumnsProcessor( + config=DropColumnsProcessorConfig( + column_names=columns_to_drop, + build_stage=BuildStage.POST_BATCH, + ), + resource_provider=self._resource_provider, + ) + ) + + return processors + + def _run_processors( + self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None + ) -> pd.DataFrame: + for processor in self._processors[stage]: + try: + dataframe = processor.process(dataframe, current_batch_number=current_batch_number) + except Exception as e: + raise DatasetProcessingError( + f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}" + ) from e + return dataframe + def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" logger.warning( diff --git a/src/data_designer/engine/dataset_builders/errors.py b/src/data_designer/engine/dataset_builders/errors.py index d99a387a..ee4b6138 100644 --- a/src/data_designer/engine/dataset_builders/errors.py +++ b/src/data_designer/engine/dataset_builders/errors.py @@ -8,3 +8,6 @@ class ArtifactStorageError(DataDesignerError): ... class DatasetGenerationError(DataDesignerError): ... + + +class DatasetProcessingError(DataDesignerError): ... diff --git a/src/data_designer/engine/dataset_builders/utils/concurrency.py b/src/data_designer/engine/dataset_builders/utils/concurrency.py index 3b305e06..b5760b6a 100644 --- a/src/data_designer/engine/dataset_builders/utils/concurrency.py +++ b/src/data_designer/engine/dataset_builders/utils/concurrency.py @@ -8,7 +8,7 @@ import json import logging from threading import Lock, Semaphore -from typing import Any, Protocol +from typing import Any, Optional, Protocol from pydantic import BaseModel, Field @@ -46,13 +46,13 @@ def is_error_rate_exceeded(self, window: int) -> bool: class CallbackWithContext(Protocol): """Executor callback functions must accept a context kw argument.""" - def __call__(self, result: Any, *, context: dict | None = None) -> Any: ... + def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ... class ErrorCallbackWithContext(Protocol): """Error callbacks take the Exception instance and context.""" - def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ... + def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ... 
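With the ColumnWiseDatasetBuilder changes above, processors are grouped by BuildStage when the builder is constructed, and both batch builds and previews are routed through _run_processors, so column dropping is no longer special-cased. A rough sketch of that dispatch using plain-Python stand-ins (not the patch's actual Processor/BuildStage types):

    from collections import defaultdict

    # Stand-in registry: stage name -> callables mimicking Processor.process(df, current_batch_number=...).
    processors_by_stage = defaultdict(list)
    processors_by_stage["post_batch"].append(lambda df, current_batch_number=None: df)

    def run_processors(stage, dataframe, current_batch_number=None):
        for process in processors_by_stage[stage]:
            dataframe = process(dataframe, current_batch_number=current_batch_number)
        return dataframe

    # Batch N:  run_processors("post_batch", batch_df, current_batch_number=N)
    # Preview:  run_processors("post_batch", preview_df, current_batch_number=None)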
class ConcurrentThreadExecutor: @@ -92,8 +92,8 @@ def __init__( *, max_workers: int, column_name: str, - result_callback: CallbackWithContext | None = None, - error_callback: ErrorCallbackWithContext | None = None, + result_callback: Optional[CallbackWithContext] = None, + error_callback: Optional[ErrorCallbackWithContext] = None, shutdown_error_rate: float = 0.50, shutdown_error_window: int = 10, ): @@ -136,7 +136,7 @@ def _raise_task_error(self): ) ) - def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None: + def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None: if self._executor is None: raise RuntimeError("Executor is not initialized, this class should be used as a context manager.") diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py index e302a99e..2a784212 100644 --- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -3,6 +3,7 @@ from data_designer.config.columns import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig +from data_designer.config.processors import ProcessorConfig from data_designer.engine.dataset_builders.multi_column_configs import ( DatasetBuilderColumnConfigT, SamplerMultiColumnConfig, @@ -50,3 +51,9 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D compiled_column_configs.extend(generated_column_configs) return compiled_column_configs + + +def compile_dataset_builder_processor_configs( + config: DataDesignerConfig, +) -> list[ProcessorConfig]: + return config.processors or [] diff --git a/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index 58826da5..240a3a64 100644 --- a/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
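The new compile_dataset_builder_processor_configs helper shown above simply surfaces config.processors (falling back to an empty list), which ColumnWiseDatasetBuilder now expects alongside the compiled column configs. A hedged wiring sketch, where config and resource_provider stand for an existing DataDesignerConfig and ResourceProvider created elsewhere:

    from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder
    from data_designer.engine.dataset_builders.utils.config_compiler import (
        compile_dataset_builder_column_configs,
        compile_dataset_builder_processor_configs,
    )

    # config and resource_provider are placeholders for objects built elsewhere.
    builder = ColumnWiseDatasetBuilder(
        column_configs=compile_dataset_builder_column_configs(config),
        processor_configs=compile_dataset_builder_processor_configs(config),  # [] when unset
        resource_provider=resource_provider,
    )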
# SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable, Container, Iterator import logging from pathlib import Path import shutil +from typing import Callable, Container, Iterator import pandas as pd import pyarrow.parquet as pq diff --git a/src/data_designer/engine/model_provider.py b/src/data_designer/engine/model_provider.py index deceef7c..e25f2e50 100644 --- a/src/data_designer/engine/model_provider.py +++ b/src/data_designer/engine/model_provider.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from functools import cached_property +from typing import Self from pydantic import BaseModel, field_validator, model_validator -from typing_extensions import Self from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError @@ -53,7 +53,7 @@ def check_default_exists(self) -> Self: raise ValueError(f"Specified default {self.default!r} not found in providers list") return self - def _get_default_provider_name(self) -> str: + def get_default_provider_name(self) -> str: return self.default or self.providers[0].name @cached_property @@ -62,7 +62,7 @@ def _providers_dict(self) -> dict[str, ModelProvider]: def get_provider(self, name: str | None) -> ModelProvider: if name is None: - name = self._get_default_provider_name() + name = self.get_default_provider_name() try: return self._providers_dict[name] diff --git a/src/data_designer/engine/models/errors.py b/src/data_designer/engine/models/errors.py index 242ddfbe..d725fa83 100644 --- a/src/data_designer/engine/models/errors.py +++ b/src/data_designer/engine/models/errors.py @@ -44,7 +44,8 @@ def get_exception_primary_cause(exception: BaseException) -> BaseException: """ if exception.__cause__ is None: return exception - return get_exception_primary_cause(exception.__cause__) + else: + return get_exception_primary_cause(exception.__cause__) class GenerationValidationFailureError(Exception): ... diff --git a/src/data_designer/engine/models/litellm_overrides.py b/src/data_designer/engine/models/litellm_overrides.py index 78f9acae..41208a0d 100644 --- a/src/data_designer/engine/models/litellm_overrides.py +++ b/src/data_designer/engine/models/litellm_overrides.py @@ -5,6 +5,7 @@ import random import threading +from typing import Optional, Union import httpx import litellm @@ -89,7 +90,7 @@ def __init__( self._initial_retry_after_s = initial_retry_after_s self._jitter_pct = jitter_pct - def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None: + def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]: """ Most of this code logic was extracted directly from the parent `Router`'s `_time_to_sleep_before_retry` function. Our override @@ -98,7 +99,7 @@ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None: return this info, we'll simply use that retry value returned here. """ - response_headers: httpx.Headers | None = None + response_headers: Optional[httpx.Headers] = None if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore response_headers = e.response.headers # type: ignore if hasattr(e, "litellm_response_headers"): @@ -109,7 +110,8 @@ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None: # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. 
if retry_after is not None and 0 < retry_after <= 60: return retry_after - return None + else: + return None @override def _time_to_sleep_before_retry( @@ -117,9 +119,9 @@ def _time_to_sleep_before_retry( e: Exception, remaining_retries: int, num_retries: int, - healthy_deployments: list | None = None, - all_deployments: list | None = None, - ) -> int | float: + healthy_deployments: Optional[list] = None, + all_deployments: Optional[list] = None, + ) -> Union[int, float]: """ Implements exponential backoff for retries. diff --git a/src/data_designer/engine/models/parsers/errors.py b/src/data_designer/engine/models/parsers/errors.py index 087e8a83..cf0fd411 100644 --- a/src/data_designer/engine/models/parsers/errors.py +++ b/src/data_designer/engine/models/parsers/errors.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + class ParserException(Exception): """Identifies errors resulting from generic parser errors. @@ -10,7 +12,7 @@ class ParserException(Exception): attempted to parse. """ - source: str | None + source: Optional[str] @staticmethod def _log_format(source: str) -> str: @@ -22,7 +24,7 @@ def _log_format(source: str) -> str: # return f"{source}" return "" - def __init__(self, msg: str | None = None, source: str | None = None): + def __init__(self, msg: Optional[str] = None, source: Optional[str] = None): msg = "" if msg is None else msg.strip() if source is not None: diff --git a/src/data_designer/engine/models/parsers/parser.py b/src/data_designer/engine/models/parsers/parser.py index 3037978a..62acdfc2 100644 --- a/src/data_designer/engine/models/parsers/parser.py +++ b/src/data_designer/engine/models/parsers/parser.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import reduce +from typing import Optional from lxml import etree from lxml.etree import _Element @@ -104,8 +105,8 @@ def __call__(self, element: _Element) -> CodeBlock: def __init__( self, - tag_parsers: dict[str, TagParser] | None = None, - postprocessors: list[PostProcessor] | None = None, + tag_parsers: Optional[dict[str, TagParser]] = None, + postprocessors: Optional[list[PostProcessor]] = None, ): """ Initializes the LLMResponseParser with optional tag parsers and post-processors. diff --git a/src/data_designer/engine/models/parsers/postprocessors.py b/src/data_designer/engine/models/parsers/postprocessors.py index 1cce5290..d7959505 100644 --- a/src/data_designer/engine/models/parsers/postprocessors.py +++ b/src/data_designer/engine/models/parsers/postprocessors.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Type import json_repair from pydantic import BaseModel, ValidationError @@ -59,12 +60,12 @@ def deserialize_json_code( class RealizePydanticTypes: - types: list[type[BaseModel]] + types: list[Type[BaseModel]] - def __init__(self, types: list[type[BaseModel]]): + def __init__(self, types: list[Type[BaseModel]]): self.types = types - def _fit_types(self, obj: dict) -> BaseModel | None: + def _fit_types(self, obj: dict) -> Optional[BaseModel]: final_obj = None for t in self.types: diff --git a/src/data_designer/engine/models/parsers/types.py b/src/data_designer/engine/models/parsers/types.py index 1fa66647..17be38a2 100644 --- a/src/data_designer/engine/models/parsers/types.py +++ b/src/data_designer/engine/models/parsers/types.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, Type, runtime_checkable from lxml.etree import _Element from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ def tail(self, n: int) -> Self: out.parsed = out.parsed[-n:] return out - def filter(self, block_types: list[type[BaseModel]]) -> Self: + def filter(self, block_types: list[Type[BaseModel]]) -> Self: out = self.model_copy() out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))] return out @@ -69,7 +69,7 @@ class TextBlock(BaseModel): class CodeBlock(BaseModel): code: str - code_lang: str | None = None + code_lang: Optional[str] = None class StructuredDataBlock(BaseModel): diff --git a/src/data_designer/engine/models/recipes/__init__.py b/src/data_designer/engine/models/recipes/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/models/recipes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/processing/__init__.py b/src/data_designer/engine/processing/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/processing/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/processing/ginja/ast.py b/src/data_designer/engine/processing/ginja/ast.py index 9171365f..2d1fecb3 100644 --- a/src/data_designer/engine/processing/ginja/ast.py +++ b/src/data_designer/engine/processing/ginja/ast.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque +from typing import Optional, Type from jinja2 import nodes as j_nodes @@ -32,7 +33,7 @@ def ast_max_depth(node: j_nodes.Node) -> int: return max_depth -def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int: +def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int: """Count the number of nodes which descend from the given node. 
Args: diff --git a/src/data_designer/engine/processing/ginja/exceptions.py b/src/data_designer/engine/processing/ginja/exceptions.py index 9e796172..780ce358 100644 --- a/src/data_designer/engine/processing/ginja/exceptions.py +++ b/src/data_designer/engine/processing/ginja/exceptions.py @@ -43,9 +43,12 @@ def maybe_handle_missing_filter_exception(exception: BaseException, available_ji match = re.search(r"No filter named '([^']+)'", exc_message) if not match: return - missing_filter_name = match.group(1) - available_filter_str = ", ".join(available_jinja_filters) - raise UserTemplateUnsupportedFiltersError( - f"The Jinja2 filter `{{{{ ... | {missing_filter_name} }}}}` " - f"is not a permitted operation. Available filters: {available_filter_str}" - ) from exception + else: + missing_filter_name = match.group(1) + available_filter_str = ", ".join(available_jinja_filters) + raise UserTemplateUnsupportedFiltersError( + ( + f"The Jinja2 filter `{{{{ ... | {missing_filter_name} }}}}` " + f"is not a permitted operation. Available filters: {available_filter_str}" + ) + ) from exception diff --git a/src/data_designer/engine/processing/processors/__init__.py b/src/data_designer/engine/processing/processors/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/processing/processors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/processing/processors/base.py b/src/data_designer/engine/processing/processors/base.py index 98b342f8..48c06576 100644 --- a/src/data_designer/engine/processing/processors/base.py +++ b/src/data_designer/engine/processing/processors/base.py @@ -12,4 +12,4 @@ class Processor(ConfigurableTask[TaskConfigT], ABC): def metadata() -> ConfigurableTaskMetadata: ... @abstractmethod - def process(self, data: DataT) -> DataT: ... + def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ... diff --git a/src/data_designer/engine/processing/processors/configs.py b/src/data_designer/engine/processing/processors/configs.py deleted file mode 100644 index 25651836..00000000 --- a/src/data_designer/engine/processing/processors/configs.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from pydantic import field_validator - -from data_designer.config.base import ConfigBase - - -class DropColumnsProcessorConfig(ConfigBase): - column_names: list[str] - dropped_column_parquet_file_name: str | None = None - - @field_validator("dropped_column_parquet_file_name") - def validate_dropped_column_parquet_file_name(cls, v: str | None) -> str | None: - if v is not None and not v.endswith(".parquet"): - raise ValueError("Dropped column parquet file name must end with .parquet") - return v diff --git a/src/data_designer/engine/processing/processors/drop_columns.py b/src/data_designer/engine/processing/processors/drop_columns.py index 88dd2943..bc996a9e 100644 --- a/src/data_designer/engine/processing/processors/drop_columns.py +++ b/src/data_designer/engine/processing/processors/drop_columns.py @@ -5,10 +5,10 @@ import pandas as pd +from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.configurable_task import ConfigurableTaskMetadata from data_designer.engine.dataset_builders.artifact_storage import BatchStage from data_designer.engine.processing.processors.base import Processor -from data_designer.engine.processing.processors.configs import DropColumnsProcessorConfig logger = logging.getLogger(__name__) @@ -22,9 +22,10 @@ def metadata() -> ConfigurableTaskMetadata: required_resources=None, ) - def process(self, data: pd.DataFrame) -> pd.DataFrame: + def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame: logger.info(f"🙈 Dropping columns: {self.config.column_names}") - self._save_dropped_columns_if_needed(data) + if current_batch_number is not None: # not in preview mode + self._save_dropped_columns_if_needed(data, current_batch_number) for column in self.config.column_names: if column in data.columns: data.drop(columns=[column], inplace=True) @@ -32,11 +33,14 @@ def process(self, data: pd.DataFrame) -> pd.DataFrame: logger.warning(f"⚠️ Cannot drop column: `{column}` not found in the dataset.") return data - def _save_dropped_columns_if_needed(self, data: pd.DataFrame) -> None: - if self.config.dropped_column_parquet_file_name: - logger.debug("📦 Saving dropped columns to dropped-columns directory") - self.artifact_storage.write_parquet_file( - parquet_file_name=self.config.dropped_column_parquet_file_name, - dataframe=data[self.config.column_names], - batch_stage=BatchStage.DROPPED_COLUMNS, - ) + def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_number: int) -> None: + logger.debug("📦 Saving dropped columns to dropped-columns directory") + dropped_column_parquet_file_name = self.artifact_storage.create_batch_file_path( + batch_number=current_batch_number, + batch_stage=BatchStage.DROPPED_COLUMNS, + ).name + self.artifact_storage.write_parquet_file( + parquet_file_name=dropped_column_parquet_file_name, + dataframe=data[self.config.column_names], + batch_stage=BatchStage.DROPPED_COLUMNS, + ) diff --git a/src/data_designer/engine/processing/processors/registry.py b/src/data_designer/engine/processing/processors/registry.py new file mode 100644 index 00000000..dadcbc33 --- /dev/null +++ b/src/data_designer/engine/processing/processors/registry.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
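A practical consequence of the DropColumnsProcessor change above: the dropped-columns artifact is now written per batch from current_batch_number, and preview runs (which pass None) skip persistence entirely. A minimal standalone sketch of that branch, with a hypothetical save callable standing in for the artifact-storage write:

    from typing import Optional

    import pandas as pd

    def drop_columns(data: pd.DataFrame, column_names: list,
                     current_batch_number: Optional[int] = None,
                     save=lambda df, batch: None) -> pd.DataFrame:
        # Persist the soon-to-be-dropped columns only when running inside a numbered batch.
        if current_batch_number is not None:
            save(data[column_names], current_batch_number)
        return data.drop(columns=[c for c in column_names if c in data.columns])

    df = pd.DataFrame({"keep": [1], "secret": ["x"]})
    drop_columns(df, ["secret"])                          # preview: nothing persisted
    drop_columns(df, ["secret"], current_batch_number=3)  # batch 3: columns saved first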
+# SPDX-License-Identifier: Apache-2.0 + +from data_designer.config.base import ConfigBase +from data_designer.config.processors import ( + DropColumnsProcessorConfig, + ProcessorType, +) +from data_designer.engine.processing.processors.base import Processor +from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor +from data_designer.engine.registry.base import TaskRegistry + + +class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ... + + +def create_default_processor_registry() -> ProcessorRegistry: + registry = ProcessorRegistry() + registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False) + return registry diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py index 2da38a69..3579b3bd 100644 --- a/src/data_designer/engine/processing/utils.py +++ b/src/data_designer/engine/processing/utils.py @@ -3,7 +3,7 @@ import json import logging -from typing import Any, TypeVar, overload +from typing import Any, TypeVar, Union, overload import pandas as pd @@ -25,7 +25,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame: # Overloads to help static type checker better understand # the input/output types of the deserialize_json_values function. @overload -def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ... +def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ... @overload diff --git a/src/data_designer/engine/registry/__init__.py b/src/data_designer/engine/registry/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/registry/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
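The new registry module above mirrors the existing column-generator and profiler registries: ProcessorType.DROP_COLUMNS is mapped to DropColumnsProcessor and its config class, which is what _initialize_processors relies on when it looks processors up by config type. A small usage sketch, assuming the patched package is importable:

    from data_designer.config.processors import DropColumnsProcessorConfig, ProcessorType
    from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
    from data_designer.engine.processing.processors.registry import create_default_processor_registry

    registry = create_default_processor_registry()
    # Both lookups resolve through the TaskRegistry mappings registered above.
    assert registry.get_for_config_type(DropColumnsProcessorConfig) is DropColumnsProcessor
    assert registry.get_config_type(ProcessorType.DROP_COLUMNS) is DropColumnsProcessorConfig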
-# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/registry/base.py b/src/data_designer/engine/registry/base.py index 5d435004..5f780940 100644 --- a/src/data_designer/engine/registry/base.py +++ b/src/data_designer/engine/registry/base.py @@ -3,7 +3,7 @@ from enum import StrEnum import threading -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Type, TypeVar from data_designer.config.base import ConfigBase from data_designer.engine.configurable_task import ConfigurableTask @@ -16,14 +16,14 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]): # registered type name -> type - _registry: dict[EnumNameT, type[TaskT]] = {} + _registry: dict[EnumNameT, Type[TaskT]] = {} # type -> registered type name - _reverse_registry: dict[type[TaskT], EnumNameT] = {} + _reverse_registry: dict[Type[TaskT], EnumNameT] = {} # registered type name -> config type - _config_registry: dict[EnumNameT, type[TaskConfigT]] = {} + _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {} # config type -> registered type name - _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {} + _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {} # all registries are singletons _instance = None @@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]): def register( cls, name: EnumNameT, - task: type[TaskT], - config: type[TaskConfigT], + task: Type[TaskT], + config: Type[TaskConfigT], raise_on_collision: bool = True, ) -> None: if cls._has_been_registered(name): @@ -52,22 +52,22 @@ def register( cls._reverse_config_registry[config] = name @classmethod - def get_task_type(cls, name: EnumNameT) -> type[TaskT]: + def get_task_type(cls, name: EnumNameT) -> Type[TaskT]: cls._raise_if_not_registered(name, cls._registry) return cls._registry[name] @classmethod - def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]: + def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]: cls._raise_if_not_registered(name, cls._config_registry) return cls._config_registry[name] @classmethod - def get_registered_name(cls, task: type[TaskT]) -> EnumNameT: + def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT: cls._raise_if_not_registered(task, cls._reverse_registry) return cls._reverse_registry[task] @classmethod - def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]: + def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]: cls._raise_if_not_registered(config, cls._reverse_config_registry) name = cls._reverse_config_registry[config] return cls.get_task_type(name) @@ -77,7 +77,7 @@ def _has_been_registered(cls, name: EnumNameT) -> bool: return name in cls._registry @classmethod - def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None: + def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None: if not (isinstance(key, StrEnum) or isinstance(key, str)): cls._raise_if_not_type(key) if key not in mapping: diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py index c498995d..8ed2f0ba 100644 --- a/src/data_designer/engine/registry/data_designer_registry.py +++ b/src/data_designer/engine/registry/data_designer_registry.py @@ -9,6 +9,10 @@ ColumnGeneratorRegistry, create_default_column_generator_registry, ) +from data_designer.engine.processing.processors.registry import ( + ProcessorRegistry, + 
create_default_processor_registry, +) class DataDesignerRegistry: @@ -17,9 +21,11 @@ def __init__( *, column_generator_registry: ColumnGeneratorRegistry | None = None, column_profiler_registry: ColumnProfilerRegistry | None = None, + processor_registry: ProcessorRegistry | None = None, ): self._column_generator_registry = column_generator_registry or create_default_column_generator_registry() self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry() + self._processor_registry = processor_registry or create_default_processor_registry() @property def column_generators(self) -> ColumnGeneratorRegistry: @@ -28,3 +34,7 @@ def column_generators(self) -> ColumnGeneratorRegistry: @property def column_profilers(self) -> ColumnProfilerRegistry: return self._column_profiler_registry + + @property + def processors(self) -> ProcessorRegistry: + return self._processor_registry diff --git a/src/data_designer/engine/resources/__init__.py b/src/data_designer/engine/resources/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/resources/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/resources/managed_storage.py b/src/data_designer/engine/resources/managed_storage.py index 946477a4..cd7cb0b8 100644 --- a/src/data_designer/engine/resources/managed_storage.py +++ b/src/data_designer/engine/resources/managed_storage.py @@ -118,7 +118,7 @@ def init_managed_blob_storage(assets_storage: str = "s3://gretel-managed-assets- return S3BlobStorageProvider(bucket_name=bucket_name) - if assets_storage.startswith("/"): + elif assets_storage.startswith("/"): path = Path(assets_storage) if not path.exists(): raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.") diff --git a/src/data_designer/engine/resources/seed_dataset_data_store.py b/src/data_designer/engine/resources/seed_dataset_data_store.py index 4b206333..3295f41e 100644 --- a/src/data_designer/engine/resources/seed_dataset_data_store.py +++ b/src/data_designer/engine/resources/seed_dataset_data_store.py @@ -2,15 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -import os -import tempfile -from datasets import DatasetDict, load_dataset import duckdb from huggingface_hub import HfApi, HfFileSystem -import pandas as pd -from data_designer.config.utils.io_helpers import validate_dataset_file_path from data_designer.logging import quiet_noisy_logger quiet_noisy_logger("httpx") @@ -31,9 +26,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ... @abstractmethod def get_dataset_uri(self, file_id: str) -> str: ... - @abstractmethod - def load_dataset(self, file_id: str) -> pd.DataFrame: ... 
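DataDesignerRegistry now exposes the processor registry through a processors property, alongside column generators and profilers, so callers resolve processor classes through the same facade. A short sketch under the same importability assumption:

    from data_designer.config.processors import DropColumnsProcessorConfig
    from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry

    registry = DataDesignerRegistry()  # falls back to the default sub-registries
    drop_processor_cls = registry.processors.get_for_config_type(DropColumnsProcessorConfig)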
- class LocalSeedDatasetDataStore(SeedDatasetDataStore): """Local filesystem-based dataset storage.""" @@ -44,20 +36,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: def get_dataset_uri(self, file_id: str) -> str: return file_id - def load_dataset(self, file_id: str) -> pd.DataFrame: - filepath = validate_dataset_file_path(file_id) - match filepath.suffix.lower(): - case ".csv": - return pd.read_csv(filepath) - case ".parquet": - return pd.read_parquet(filepath) - case ".json": - return pd.read_json(filepath, lines=True) - case ".jsonl": - return pd.read_json(filepath, lines=True) - case _: - raise ValueError("Local datasets must be CSV, Parquet, JSON, or JSONL") - class HfHubSeedDatasetDataStore(SeedDatasetDataStore): """Hugging Face and Data Store dataset storage.""" @@ -76,54 +54,6 @@ def get_dataset_uri(self, file_id: str) -> str: repo_id, filename = self._get_repo_id_and_filename(identifier) return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}" - def load_dataset(self, file_id: str) -> pd.DataFrame: - identifier = file_id.removeprefix(_HF_DATASETS_PREFIX) - repo_id, filename = self._get_repo_id_and_filename(identifier) - is_file = "." in file_id.split("/")[-1] - - self._validate_repo(repo_id) - - if is_file: - self._validate_file(repo_id, filename) - return self._download_and_load_file(repo_id, filename) - return self._download_and_load_directory(repo_id, filename) - - def _validate_repo(self, repo_id: str) -> None: - """Validate that the repository exists and is a dataset repo.""" - if not self.hfapi.repo_exists(repo_id, repo_type="dataset"): - if self.hfapi.repo_exists(repo_id, repo_type="model"): - raise FileNotFoundError(f"Repo {repo_id} is a model repo, not a dataset repo") - raise FileNotFoundError(f"Repo {repo_id} does not exist") - - def _validate_file(self, repo_id: str, filename: str) -> None: - """Validate that the file exists in the repository.""" - if not self.hfapi.file_exists(repo_id, filename, repo_type="dataset"): - raise FileNotFoundError(f"File {filename} does not exist in repo {repo_id}") - - def _download_and_load_file(self, repo_id: str, filename: str) -> pd.DataFrame: - """Download a specific file and load it as a dataset.""" - with tempfile.TemporaryDirectory() as temp_dir: - self.hfapi.hf_hub_download( - repo_id=repo_id, - filename=filename, - local_dir=temp_dir, - repo_type="dataset", - ) - return self._load_local_dataset(temp_dir) - - def _download_and_load_directory(self, repo_id: str, directory: str) -> pd.DataFrame: - """Download entire repo and load from specific subdirectory.""" - with tempfile.TemporaryDirectory() as temp_dir: - self.hfapi.snapshot_download( - repo_id=repo_id, - local_dir=temp_dir, - repo_type="dataset", - ) - dataset_path = os.path.join(temp_dir, directory) - if not os.path.exists(dataset_path): - dataset_path = temp_dir - return self._load_local_dataset(dataset_path) - def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]: """Extract repo_id and filename from identifier.""" parts = identifier.split("/", 2) @@ -134,10 +64,3 @@ def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]: ) repo_ns, repo_name, filename = parts return f"{repo_ns}/{repo_name}", filename - - def _load_local_dataset(self, path: str) -> pd.DataFrame: - """Load dataset from local path.""" - hf_dataset = load_dataset(path=path) - if isinstance(hf_dataset, DatasetDict): - hf_dataset = hf_dataset[list(hf_dataset.keys())[0]] - return hf_dataset.to_pandas() diff --git 
a/src/data_designer/engine/sampling_gen/__init__.py b/src/data_designer/engine/sampling_gen/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/sampling_gen/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/sampling_gen/constraints.py b/src/data_designer/engine/sampling_gen/constraints.py index 4d719de4..30ee688f 100644 --- a/src/data_designer/engine/sampling_gen/constraints.py +++ b/src/data_designer/engine/sampling_gen/constraints.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from typing import Type import numpy as np from numpy.typing import NDArray @@ -90,5 +91,5 @@ def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]: } -def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]: +def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]: return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)] diff --git a/src/data_designer/engine/sampling_gen/data_sources/__init__.py b/src/data_designer/engine/sampling_gen/data_sources/__init__.py deleted file mode 100644 index 4ee5de4a..00000000 --- a/src/data_designer/engine/sampling_gen/data_sources/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - diff --git a/src/data_designer/engine/sampling_gen/data_sources/base.py b/src/data_designer/engine/sampling_gen/data_sources/base.py index 7d8961fd..acfd2e2a 100644 --- a/src/data_designer/engine/sampling_gen/data_sources/base.py +++ b/src/data_designer/engine/sampling_gen/data_sources/base.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Optional, Type, TypeVar, Union import numpy as np from numpy.typing import NDArray @@ -45,7 +45,7 @@ def postproc(series: pd.Series, convert_to: str) -> pd.Series: return series @staticmethod - def validate_data_conversion(convert_to: str | None) -> None: + def validate_data_conversion(convert_to: Optional[str]) -> None: pass @@ -71,7 +71,7 @@ def preproc(series: pd.Series, convert_to: str) -> pd.Series: return series @staticmethod - def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: + def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: if convert_to is not None: if convert_to == "int": series = series.round() @@ -79,18 +79,18 @@ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: return series @staticmethod - def validate_data_conversion(convert_to: str | None) -> None: + def validate_data_conversion(convert_to: Optional[str]) -> None: if convert_to is not None and convert_to not in ["float", "int", "str"]: raise ValueError(f"Invalid `convert_to` value: {convert_to}. 
Must be one of: [float, int, str]") class DatetimeFormatMixin: @staticmethod - def preproc(series: pd.Series, convert_to: str | None) -> pd.Series: + def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: return series @staticmethod - def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: + def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: if convert_to is not None: return series.dt.strftime(convert_to) if series.dt.month.nunique() == 1: @@ -104,7 +104,7 @@ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: return series.apply(lambda dt: dt.isoformat()).astype(str) @staticmethod - def validate_data_conversion(convert_to: str | None) -> None: + def validate_data_conversion(convert_to: Optional[str]) -> None: if convert_to is not None: try: pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to)) @@ -121,7 +121,7 @@ class DataSource(ABC, Generic[GenericParamsT]): def __init__( self, params: GenericParamsT, - random_state: RadomStateT | None = None, + random_state: Optional[RadomStateT] = None, **kwargs, ): self.rng = check_random_state(random_state) @@ -130,7 +130,7 @@ def __init__( self._validate() @classmethod - def get_param_type(cls) -> type[GenericParamsT]: + def get_param_type(cls) -> Type[GenericParamsT]: return cls.__orig_bases__[-1].__args__[0] @abstractmethod @@ -138,7 +138,7 @@ def inject_data_column( self, dataframe: pd.DataFrame, column_name: str, - index: list[int] | None = None, + index: Optional[list[int]] = None, ) -> pd.DataFrame: ... @staticmethod @@ -147,11 +147,11 @@ def preproc(series: pd.Series) -> pd.Series: ... @staticmethod @abstractmethod - def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ... + def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: ... @staticmethod @abstractmethod - def validate_data_conversion(convert_to: str | None) -> None: ... + def validate_data_conversion(convert_to: Optional[str]) -> None: ... def get_required_column_names(self) -> tuple[str, ...]: return tuple() @@ -182,7 +182,7 @@ def inject_data_column( self, dataframe: pd.DataFrame, column_name: str, - index: list[int] | None = None, + index: Optional[list[int]] = None, ) -> pd.DataFrame: index = slice(None) if index is None else index @@ -208,7 +208,7 @@ def sample(self, num_samples: int) -> NumpyArray1dT: ... class ScipyStatsSampler(Sampler[GenericParamsT], ABC): @property @abstractmethod - def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ... + def distribution(self) -> Union[stats.rv_continuous, stats.rv_discrete]: ... 
def sample(self, num_samples: int) -> NumpyArray1dT: return self.distribution.rvs(size=num_samples, random_state=self.rng) diff --git a/src/data_designer/engine/sampling_gen/entities/email_address_utils.py b/src/data_designer/engine/sampling_gen/entities/email_address_utils.py index 17db6d2f..7dce2d3c 100644 --- a/src/data_designer/engine/sampling_gen/entities/email_address_utils.py +++ b/src/data_designer/engine/sampling_gen/entities/email_address_utils.py @@ -77,17 +77,18 @@ def get_email_domain_by_age(age: int) -> str: weights=list(email_domains_under_30.values()), k=1, )[0] - if age < 50: + elif age < 50: return random.choices( list(email_domains_30_50.keys()), weights=list(email_domains_30_50.values()), k=1, )[0] - return random.choices( - list(email_domains_over_50.keys()), - weights=list(email_domains_over_50.values()), - k=1, - )[0] + else: + return random.choices( + list(email_domains_over_50.keys()), + weights=list(email_domains_over_50.values()), + k=1, + )[0] def get_email_basename_by_name(first_name: str, middle_name: str | None, last_name: str) -> str: diff --git a/src/data_designer/engine/sampling_gen/entities/person.py b/src/data_designer/engine/sampling_gen/entities/person.py index 1b2b1f15..685916f8 100644 --- a/src/data_designer/engine/sampling_gen/entities/person.py +++ b/src/data_designer/engine/sampling_gen/entities/person.py @@ -78,10 +78,10 @@ def generate_phone_number(locale: str, age: int, postcode: str | None, style: st if locality_var < 0.6: # Exact match to postcode 60% of the time return PhoneNumber.from_zip_prefix(postcode).format(style=style) - if locality_var < 0.8: + elif locality_var < 0.8: # Nearby postcodes 20% of the time return PhoneNumber.from_zip_prefix(postcode[:4]).format(style=style) - if locality_var < 0.9: + elif locality_var < 0.9: # More distant postcodes 10% of the time return PhoneNumber.from_zip_prefix(postcode[:3]).format(style=style) # Random (population-weighted) area code 10% of the time diff --git a/src/data_designer/engine/sampling_gen/entities/phone_number.py b/src/data_designer/engine/sampling_gen/entities/phone_number.py index 422c3644..20a7394c 100644 --- a/src/data_designer/engine/sampling_gen/entities/phone_number.py +++ b/src/data_designer/engine/sampling_gen/entities/phone_number.py @@ -3,16 +3,17 @@ from pathlib import Path import random +from typing import Optional import pandas as pd from pydantic import BaseModel, Field, field_validator ZIP_AREA_CODE_DATA = pd.read_parquet(Path(__file__).parent / "assets" / "zip_area_code_map.parquet") -ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"], strict=False)) -ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"], strict=False)) +ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"])) +ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"])) -def get_area_code(zip_prefix: str | None = None) -> str: +def get_area_code(zip_prefix: Optional[str] = None) -> str: """ Sample an area code for the given ZIP code prefix, population-weighted. @@ -23,7 +24,7 @@ def get_area_code(zip_prefix: str | None = None) -> str: A sampled area code matching the prefix, population-weighted. 
""" if zip_prefix is None: - zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items(), strict=False) + zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items()) zipcode = random.choices(zipcodes, weights=weights, k=1)[0] return str(ZIPCODE_AREA_CODE_MAP[zipcode]) if len(zip_prefix) == 5: @@ -32,7 +33,7 @@ def get_area_code(zip_prefix: str | None = None) -> str: except KeyError: raise ValueError(f"ZIP code {zip_prefix} not found.") matching_zipcodes = [[z, c] for z, c in ZIPCODE_POPULATION_MAP.items() if z.startswith(zip_prefix)] - zipcodes, weights = zip(*matching_zipcodes, strict=False) + zipcodes, weights = zip(*matching_zipcodes) if not zipcodes: raise ValueError(f"No ZIP codes found with prefix {zip_prefix}.") zipcode = random.choices(zipcodes, weights=weights, k=1)[0] diff --git a/src/data_designer/engine/secret_resolver.py b/src/data_designer/engine/secret_resolver.py index 0da339eb..521064d4 100644 --- a/src/data_designer/engine/secret_resolver.py +++ b/src/data_designer/engine/secret_resolver.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence import json import logging import os @@ -30,7 +31,7 @@ def resolve(self, secret: str) -> str: try: return self._secrets[secret] except KeyError: - raise SecretResolutionError(f"No secret found with key {secret!r}") + raise SecretResolutionError(f"No secret found in secrets file with key {secret!r}") class EnvironmentResolver(SecretResolver): @@ -47,18 +48,22 @@ def resolve(self, secret: str) -> str: class CompositeResolver(SecretResolver): - _resolvers: list[SecretResolver] + _resolvers: Sequence[SecretResolver] - def __init__(self, resolvers: list[SecretResolver]): + def __init__(self, resolvers: Sequence[SecretResolver]): if len(resolvers) == 0: raise SecretResolutionError("Must provide at least one SecretResolver to CompositeResolver") self._resolvers = resolvers def resolve(self, secret: str) -> str: + errors = [] for resolver in self._resolvers: try: return resolver.resolve(secret) - except SecretResolutionError: + except SecretResolutionError as err: + errors.append(str(err)) continue - raise SecretResolutionError(f"No configured resolvers were able to resolve secret {secret!r}") + raise SecretResolutionError( + f"No configured resolvers were able to resolve secret {secret!r}: {', '.join(errors)}" + ) diff --git a/src/data_designer/engine/validators/base.py b/src/data_designer/engine/validators/base.py index d18ab678..18f9597e 100644 --- a/src/data_designer/engine/validators/base.py +++ b/src/data_designer/engine/validators/base.py @@ -2,14 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections.abc import Iterator +from typing import Iterator, Optional, Self from pydantic import BaseModel, ConfigDict -from typing_extensions import Self class ValidationOutput(BaseModel): - is_valid: bool | None + is_valid: Optional[bool] model_config = ConfigDict(extra="allow") diff --git a/src/data_designer/engine/validators/python.py b/src/data_designer/engine/validators/python.py index f71b678b..5b42265e 100644 --- a/src/data_designer/engine/validators/python.py +++ b/src/data_designer/engine/validators/python.py @@ -238,7 +238,7 @@ def _get_scores(stats_by_module: dict[str, dict[str, int]]) -> dict[str, float]: def _count_python_statements(file_path: str) -> int: """Count the number of statements in a Python file.""" try: - with open(file_path, encoding="utf-8") as f: 
+ with open(file_path, "r", encoding="utf-8") as f: tree = ast.parse(f.read()) return sum(1 for node in ast.walk(tree) if isinstance(node, ast.stmt)) except Exception: diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py index 50e21493..3be1950f 100644 --- a/src/data_designer/essentials/__init__.py +++ b/src/data_designer/essentials/__init__.py @@ -16,6 +16,7 @@ ) from ..config.config_builder import DataDesignerConfigBuilder from ..config.data_designer_config import DataDesignerConfig +from ..config.dataset_builders import BuildStage from ..config.datastore import DatastoreSettings from ..config.models import ( ImageContext, @@ -30,6 +31,7 @@ UniformDistribution, UniformDistributionParams, ) +from ..config.processors import DropColumnsProcessorConfig, ProcessorType from ..config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from ..config.sampler_params import ( BernoulliMixtureSamplerParams, @@ -80,9 +82,11 @@ "DataDesignerColumnType", "DataDesignerConfig", "DataDesignerConfigBuilder", + "BuildStage", "DatastoreSeedDatasetReference", "DatastoreSettings", "DatetimeSamplerParams", + "DropColumnsProcessorConfig", "ExpressionColumnConfig", "GaussianSamplerParams", "ImageContext", @@ -102,6 +106,7 @@ "ModelConfig", "PersonSamplerParams", "PoissonSamplerParams", + "ProcessorType", "RemoteValidatorParams", "SamplerColumnConfig", "SamplerType", diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py index b0d79aef..13d16cb5 100644 --- a/src/data_designer/interface/data_designer.py +++ b/src/data_designer/interface/data_designer.py @@ -204,8 +204,6 @@ def preview( except Exception as e: raise DataDesignerProfilingError(f"🛑 Error profiling preview dataset: {e}") - dataset = builder.drop_columns_if_needed(dataset) - if len(dataset) > 0 and isinstance(analysis, DatasetProfilerResults) and len(analysis.column_statistics) > 0: logger.info(f"{RandomEmoji.success()} Preview complete!") @@ -237,6 +235,7 @@ def _create_dataset_builder( ) -> ColumnWiseDatasetBuilder: return ColumnWiseDatasetBuilder( column_configs=compile_dataset_builder_column_configs(config_builder.build(raise_exceptions=True)), + processor_configs=config_builder.get_processor_configs(), resource_provider=resource_provider, ) diff --git a/src/data_designer/logging.py b/src/data_designer/logging.py index e50bf752..77a8210d 100644 --- a/src/data_designer/logging.py +++ b/src/data_designer/logging.py @@ -6,7 +6,7 @@ from pathlib import Path import random import sys -from typing import TextIO +from typing import TextIO, Union from pythonjsonlogger import jsonlogger @@ -19,7 +19,7 @@ class LoggerConfig: @dataclass class OutputConfig: - destination: TextIO | Path + destination: Union[TextIO, Path] structured: bool diff --git a/tests/config/analysis/utils/test_reporting.py b/tests/config/analysis/utils/test_reporting.py index 10625a6a..ec9a9912 100644 --- a/tests/config/analysis/utils/test_reporting.py +++ b/tests/config/analysis/utils/test_reporting.py @@ -60,7 +60,7 @@ def test_generate_analysis_report_with_save_path_html(sample_dataset_profiler_re assert Path(tmp_path).stat().st_size > 0 # Verify it's valid HTML by checking for basic HTML structure - with open(tmp_path) as f: + with open(tmp_path, "r") as f: content = f.read() assert " 0 # Verify it's valid SVG by checking for SVG structure - with open(tmp_path) as f: + with open(tmp_path, "r") as f: content = f.read() assert " 0 # Verify it's valid HTML by checking 
for basic HTML structure - with open(tmp_path) as f: + with open(tmp_path, "r") as f: content = f.read() assert " DataDesignerDatasetProfiler: # Ensure the final dataset path exists - with open(stub_dataset_path / "column_configs.json") as f: + with open(stub_dataset_path / "column_configs.json", "r") as f: column_configs = json.load(f) model_config = Mock(spec=ModelConfig) diff --git a/tests/engine/dataset_builders/test_artifact_storage.py b/tests/engine/dataset_builders/test_artifact_storage.py index d721ef7e..be944432 100644 --- a/tests/engine/dataset_builders/test_artifact_storage.py +++ b/tests/engine/dataset_builders/test_artifact_storage.py @@ -105,7 +105,7 @@ def test_artifact_storage_write_metadata(stub_artifact_storage): assert file_path.exists() assert file_path == stub_artifact_storage.metadata_file_path - with open(file_path) as f: + with open(file_path, "r") as f: loaded_metadata = json.load(f) assert loaded_metadata == metadata @@ -126,7 +126,7 @@ def model_dump(self, mode="json"): assert file_path.name == "configs.json" assert file_path.parent == stub_artifact_storage.base_dataset_path - with open(file_path) as f: + with open(file_path, "r") as f: loaded_configs = json.load(f) expected = [{"name": "config1", "value": 1}, {"name": "config2", "value": 2}] assert loaded_configs == expected diff --git a/tests/engine/dataset_builders/test_column_wise_builder.py b/tests/engine/dataset_builders/test_column_wise_builder.py index e062cf21..487efca7 100644 --- a/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/tests/engine/dataset_builders/test_column_wise_builder.py @@ -7,6 +7,8 @@ import pytest from data_designer.config.columns import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.dataset_builders import BuildStage +from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.dataset_builders.column_wise_builder import ( MAX_CONCURRENCY_PER_NON_LLM_GENERATOR, ColumnWiseDatasetBuilder, @@ -18,32 +20,53 @@ @pytest.fixture def stub_test_column_configs(): - return [LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model")] + return [ + LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model"), + LLMTextColumnConfig(name="column_to_drop", prompt="Test prompt", model_alias="test_model"), + ] + + +@pytest.fixture +def stub_test_processor_configs(): + return [DropColumnsProcessorConfig(build_stage=BuildStage.POST_BATCH, column_names=["column_to_drop"])] @pytest.fixture def stub_batch_manager(): mock_batch_manager = Mock() mock_batch_manager.num_batches = 2 - mock_batch_manager.num_records_batch = 10 + mock_batch_manager.num_records_batch = 3 mock_batch_manager.finish = Mock() mock_batch_manager.write = Mock() mock_batch_manager.add_records = Mock() mock_batch_manager.update_records = Mock() mock_batch_manager.update_record = Mock() - mock_batch_manager.get_current_batch = Mock(return_value=pd.DataFrame({"existing": [1, 2, 3]})) - mock_batch_manager.get_current_batch_number = Mock(return_value=1) + mock_batch_manager.get_current_batch = Mock() + mock_batch_manager.get_current_batch.side_effect = [ + pd.DataFrame({"test_column": [1, 2, 3], "column_to_drop": [1, 2, 3]}), + pd.DataFrame({"test_column": [4, 5, 6], "column_to_drop": [4, 5, 6]}), + ] + mock_batch_manager.get_current_batch_number = Mock() + mock_batch_manager.get_current_batch_number.side_effect = [1, 2] return mock_batch_manager @pytest.fixture -def stub_column_wise_builder(stub_resource_provider, 
stub_test_column_configs): - return ColumnWiseDatasetBuilder(column_configs=stub_test_column_configs, resource_provider=stub_resource_provider) +def stub_column_wise_builder(stub_resource_provider, stub_test_column_configs, stub_test_processor_configs): + return ColumnWiseDatasetBuilder( + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) -def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_column_configs): +def test_column_wise_dataset_builder_creation( + stub_resource_provider, stub_test_column_configs, stub_test_processor_configs +): builder = ColumnWiseDatasetBuilder( - column_configs=stub_test_column_configs, resource_provider=stub_resource_provider + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, ) assert builder._column_configs == stub_test_column_configs @@ -51,11 +74,16 @@ def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_ assert isinstance(builder._registry, DataDesignerRegistry) -def test_column_wise_dataset_builder_creation_with_custom_registry(stub_resource_provider, stub_test_column_configs): +def test_column_wise_dataset_builder_creation_with_custom_registry( + stub_resource_provider, stub_test_column_configs, stub_test_processor_configs +): custom_registry = Mock(spec=DataDesignerRegistry) builder = ColumnWiseDatasetBuilder( - column_configs=stub_test_column_configs, resource_provider=stub_resource_provider, registry=custom_registry + column_configs=stub_test_column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + registry=custom_registry, ) assert builder._registry == custom_registry @@ -85,18 +113,26 @@ def test_column_wise_dataset_builder_batch_manager_initialization(stub_column_wi ], ) def test_column_wise_dataset_builder_single_column_configs_property( - stub_resource_provider, config_type, expected_single_configs + stub_resource_provider, stub_test_processor_configs, config_type, expected_single_configs ): if config_type == "single": single_config = LLMTextColumnConfig(name="test_column", prompt="Test prompt", model_alias="test_model") - builder = ColumnWiseDatasetBuilder(column_configs=[single_config], resource_provider=stub_resource_provider) + builder = ColumnWiseDatasetBuilder( + column_configs=[single_config], + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) assert builder.single_column_configs == [single_config] else: sampler_config = SamplerColumnConfig( name="sampler_col", sampler_type="category", params={"values": ["A", "B", "C"]} ) multi_config = SamplerMultiColumnConfig(columns=[sampler_config]) - builder = ColumnWiseDatasetBuilder(column_configs=[multi_config], resource_provider=stub_resource_provider) + builder = ColumnWiseDatasetBuilder( + column_configs=[multi_config], + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) assert builder.single_column_configs == [sampler_config] @@ -135,7 +171,9 @@ def test_column_wise_dataset_builder_build_method_basic_flow( ), ], ) -def test_column_wise_dataset_builder_validate_column_configs(stub_resource_provider, column_configs, expected_error): +def test_column_wise_dataset_builder_validate_column_configs( + stub_test_processor_configs, stub_resource_provider, column_configs, expected_error +): if expected_error == "The first column config 
must be a from-scratch column generator": mock_registry = Mock() mock_generator_class = Mock() @@ -144,11 +182,28 @@ def test_column_wise_dataset_builder_validate_column_configs(stub_resource_provi with pytest.raises(DatasetGenerationError, match=expected_error): ColumnWiseDatasetBuilder( - column_configs=column_configs, resource_provider=stub_resource_provider, registry=mock_registry + column_configs=column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + registry=mock_registry, ) else: with pytest.raises(DatasetGenerationError, match=expected_error): - ColumnWiseDatasetBuilder(column_configs=column_configs, resource_provider=stub_resource_provider) + ColumnWiseDatasetBuilder( + column_configs=column_configs, + processor_configs=stub_test_processor_configs, + resource_provider=stub_resource_provider, + ) + + +def test_column_wise_dataset_builder_initialize_processors(stub_column_wise_builder): + processors = stub_column_wise_builder._processors + assert processors.keys() == set(BuildStage) + assert len(processors[BuildStage.PRE_BATCH]) == 0 + assert len(processors[BuildStage.POST_BATCH]) == 1 + assert len(processors[BuildStage.PRE_GENERATION]) == 0 + assert len(processors[BuildStage.POST_GENERATION]) == 0 + assert processors[BuildStage.POST_BATCH][0].config.column_names == ["column_to_drop"] def test_constants_max_concurrency_constant(): diff --git a/tests/engine/processing/ginja/test_ast.py b/tests/engine/processing/ginja/test_ast.py index 198cb159..b4c59163 100644 --- a/tests/engine/processing/ginja/test_ast.py +++ b/tests/engine/processing/ginja/test_ast.py @@ -34,7 +34,20 @@ def stub_name_node(): ], ) def test_ast_max_depth(stub_node, test_case, children_structure, expected_depth): - if test_case == "three_levels" or test_case == "unbalanced_tree": + if test_case == "three_levels": + root = Mock(spec=j_nodes.Node) + child1 = Mock(spec=j_nodes.Node) + child2 = Mock(spec=j_nodes.Node) + grandchild = Mock(spec=j_nodes.Node) + + grandchild.iter_child_nodes.return_value = [] + child1.iter_child_nodes.return_value = [grandchild] + child2.iter_child_nodes.return_value = [] + root.iter_child_nodes.return_value = [child1, child2] + + result = ast_max_depth(root) + assert result == expected_depth + elif test_case == "unbalanced_tree": root = Mock(spec=j_nodes.Node) child1 = Mock(spec=j_nodes.Node) child2 = Mock(spec=j_nodes.Node) diff --git a/tests/engine/processing/ginja/test_environment.py b/tests/engine/processing/ginja/test_environment.py index 4883317c..e723b087 100644 --- a/tests/engine/processing/ginja/test_environment.py +++ b/tests/engine/processing/ginja/test_environment.py @@ -199,10 +199,11 @@ def __init__(self, template_1: str, template_2: str = None): def bar(self, record): if template_2 is None: return [self.render_template(record) for _ in range(n)] - return [ - self.render_multi_template("template_1", record), - self.render_multi_template("template_2", record), - ] + else: + return [ + self.render_multi_template("template_1", record), + self.render_multi_template("template_2", record), + ] if test_case.startswith("valid"): f = Foo(template_1, template_2) diff --git a/tests/engine/processing/processors/test_drop_columns.py b/tests/engine/processing/processors/test_drop_columns.py index 0f74d87b..d1fda06c 100644 --- a/tests/engine/processing/processors/test_drop_columns.py +++ b/tests/engine/processing/processors/test_drop_columns.py @@ -6,21 +6,27 @@ import pandas as pd import pytest +from 
data_designer.config.dataset_builders import BuildStage +from data_designer.config.processors import DropColumnsProcessorConfig from data_designer.engine.dataset_builders.artifact_storage import BatchStage -from data_designer.engine.processing.processors.configs import DropColumnsProcessorConfig from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor @pytest.fixture def stub_processor_config(): - return DropColumnsProcessorConfig(column_names=["col1", "col2"]) + return DropColumnsProcessorConfig(build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"]) @pytest.fixture def stub_processor(stub_processor_config): mock_resource_provider = Mock() mock_resource_provider.artifact_storage = Mock() - processor = DropColumnsProcessor(config=stub_processor_config, resource_provider=mock_resource_provider) + mock_resource_provider.artifact_storage.create_batch_file_path = Mock() + mock_resource_provider.artifact_storage.create_batch_file_path.return_value.name = "dropped.parquet" + processor = DropColumnsProcessor( + config=stub_processor_config, + resource_provider=mock_resource_provider, + ) return processor @@ -29,11 +35,6 @@ def stub_empty_dataframe(): return pd.DataFrame() -def test_config_parquet_filename_validation_invalid(): - with pytest.raises(ValueError, match="Dropped column parquet file name must end with .parquet"): - DropColumnsProcessorConfig(column_names=["col1"], dropped_column_parquet_file_name="test.txt") - - def test_metadata(): metadata = DropColumnsProcessor.metadata() @@ -106,12 +107,11 @@ def test_process_logging(stub_processor, stub_sample_dataframe): mock_logger.info.assert_called_once_with("🙈 Dropping columns: ['col1', 'col2']") -def test_save_dropped_columns_with_filename(stub_processor, stub_sample_dataframe): +def test_save_dropped_columns_without_preview(stub_processor, stub_sample_dataframe): stub_processor.config.column_names = ["col1", "col2"] - stub_processor.config.dropped_column_parquet_file_name = "dropped.parquet" with patch("data_designer.engine.processing.processors.drop_columns.logger") as mock_logger: - stub_processor.process(stub_sample_dataframe.copy()) + stub_processor.process(stub_sample_dataframe.copy(), current_batch_number=0) stub_processor.artifact_storage.write_parquet_file.assert_called_once() call_args = stub_processor.artifact_storage.write_parquet_file.call_args @@ -126,9 +126,8 @@ def test_save_dropped_columns_with_filename(stub_processor, stub_sample_datafram mock_logger.debug.assert_called_once_with("📦 Saving dropped columns to dropped-columns directory") -def test_save_dropped_columns_without_filename(stub_processor, stub_sample_dataframe): +def test_save_dropped_columns_with_preview(stub_processor, stub_sample_dataframe): stub_processor.config.column_names = ["col1", "col2"] - stub_processor.config.dropped_column_parquet_file_name = None stub_processor.process(stub_sample_dataframe.copy()) stub_processor.artifact_storage.write_parquet_file.assert_not_called() @@ -136,11 +135,10 @@ def test_save_dropped_columns_without_filename(stub_processor, stub_sample_dataf def test_save_dropped_columns_with_nonexistent_columns(stub_processor, stub_sample_dataframe): stub_processor.config.column_names = ["nonexistent1", "nonexistent2"] - stub_processor.config.dropped_column_parquet_file_name = "dropped.parquet" with patch("data_designer.engine.processing.processors.drop_columns.logger"): with pytest.raises(KeyError): - stub_processor.process(stub_sample_dataframe.copy()) + 
stub_processor.process(stub_sample_dataframe.copy(), current_batch_number=0) def test_process_inplace_modification(stub_processor, stub_sample_dataframe): diff --git a/tests/engine/processing/processors/test_registry.py b/tests/engine/processing/processors/test_registry.py new file mode 100644 index 00000000..41ccf5a8 --- /dev/null +++ b/tests/engine/processing/processors/test_registry.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.config.processors import DropColumnsProcessorConfig, ProcessorType +from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor +from data_designer.engine.processing.processors.registry import ( + ProcessorRegistry, + create_default_processor_registry, +) + + +def test_create_default_processor_registry(): + registry = create_default_processor_registry() + + assert isinstance(registry, ProcessorRegistry) + assert ProcessorType.DROP_COLUMNS in ProcessorRegistry._registry + assert ProcessorRegistry._registry[ProcessorType.DROP_COLUMNS] == DropColumnsProcessor + assert ProcessorRegistry._config_registry[ProcessorType.DROP_COLUMNS] == DropColumnsProcessorConfig diff --git a/tests/engine/resources/test_remote_seed_dataset_data_store.py b/tests/engine/resources/test_remote_seed_dataset_data_store.py deleted file mode 100644 index e94a1e33..00000000 --- a/tests/engine/resources/test_remote_seed_dataset_data_store.py +++ /dev/null @@ -1,194 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import os -from unittest.mock import Mock, patch - -import pandas as pd -import pytest - -from data_designer.engine.resources.seed_dataset_data_store import HfHubSeedDatasetDataStore - - -@pytest.mark.skipif(not os.environ.get("PUBLIC_HF_TOKEN"), reason="PUBLIC_HF_TOKEN environment variable not set") -def test_hf_hub_seed_dataset_data_store_integration_public_huggingface_file(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=os.environ["PUBLIC_HF_TOKEN"]) - - dataset = hf_store.load_dataset("hf://datasets/HuggingFaceFW/fineweb-2/data/aba_Latn/train/000_00000.parquet") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - assert len(dataset.columns) > 0 - - -@pytest.mark.skipif(not os.environ.get("PUBLIC_HF_TOKEN"), reason="PUBLIC_HF_TOKEN environment variable not set") -def test_hf_hub_seed_dataset_data_store_integration_public_huggingface_directory(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=os.environ["PUBLIC_HF_TOKEN"]) - - dataset = hf_store.load_dataset("hf://datasets/HuggingFaceFW/fineweb-2/data/aba_Latn") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - assert len(dataset.columns) > 0 - - -@pytest.mark.skipif( - not os.environ.get("NVIDIA_DATASTORE_TOKEN"), reason="NVIDIA_DATASTORE_TOKEN environment variable not set" -) -def test_hf_hub_seed_dataset_data_store_integration_nvidia_datastore_file(): - datastore = HfHubSeedDatasetDataStore( - endpoint="https://datastore.int.aire.nvidia.com/v1/hf", - token=os.environ.get("NVIDIA_DATASTORE_TOKEN"), - ) - - dataset = datastore.load_dataset("hf://datasets/anesterenko/tmp-repo-777/train/folder_with_file/000_00000.parquet") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - assert len(dataset.columns) > 0 - - -@pytest.mark.skipif( - not 
os.environ.get("NVIDIA_DATASTORE_TOKEN"), reason="NVIDIA_DATASTORE_TOKEN environment variable not set" -) -def test_hf_hub_seed_dataset_data_store_integration_nvidia_datastore_directory(): - datastore = HfHubSeedDatasetDataStore( - endpoint="https://datastore.int.aire.nvidia.com/v1/hf", - token=os.environ.get("NVIDIA_DATASTORE_TOKEN"), - ) - - dataset = datastore.load_dataset("hf://datasets/anesterenko/tmp-repo-777/train/folder_with_important_files") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - assert len(dataset.columns) > 0 - - -def test_hf_hub_seed_dataset_data_store_integration_public_huggingface_no_token(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=None) - - dataset = hf_store.load_dataset("hf://datasets/HuggingFaceFW/fineweb-2/data/aba_Latn/train/000_00000.parquet") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - - -@pytest.mark.skipif( - not os.environ.get("NVIDIA_DATASTORE_TOKEN"), reason="NVIDIA_DATASTORE_TOKEN environment variable not set" -) -def test_hf_hub_seed_dataset_data_store_integration_nvidia_datastore_no_token(): - datastore = HfHubSeedDatasetDataStore( - endpoint="https://datastore.int.aire.nvidia.com/v1/hf", - token=None, - ) - - dataset = datastore.load_dataset("hf://datasets/anesterenko/tmp-repo-777/train/folder_with_file/000_00000.parquet") - - assert isinstance(dataset, pd.DataFrame) - assert len(dataset) > 0 - - -def test_hf_hub_seed_dataset_data_store_integration_invalid_dataset_path(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=None) - - with pytest.raises(Exception): # Should raise FileNotFoundError or similar - hf_store.load_dataset("hf://datasets/nonexistent/repo/file.parquet") - - -def test_hf_hub_seed_dataset_data_store_integration_malformed_file_id(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=None) - - with pytest.raises(Exception): # Should raise MalformedFileIdError - hf_store.load_dataset("hf://datasets/invalid") - - -@pytest.mark.skipif(not os.environ.get("PUBLIC_HF_TOKEN"), reason="PUBLIC_HF_TOKEN environment variable not set") -def test_hf_hub_seed_dataset_data_store_integration_duckdb_connection(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=os.environ["PUBLIC_HF_TOKEN"]) - - conn = hf_store.create_duckdb_connection() - - result = conn.execute("SELECT 1 as test").fetchone() - assert result[0] == 1 - - conn.close() - - -@pytest.mark.skipif(not os.environ.get("PUBLIC_HF_TOKEN"), reason="PUBLIC_HF_TOKEN environment variable not set") -def test_hf_hub_seed_dataset_data_store_integration_dataset_uri_generation(): - hf_store = HfHubSeedDatasetDataStore(endpoint="https://huggingface.co", token=os.environ["PUBLIC_HF_TOKEN"]) - - file_id = "hf://datasets/HuggingFaceFW/fineweb-2/data/aba_Latn/train/000_00000.parquet" - uri = hf_store.get_dataset_uri(file_id) - - assert uri == file_id # Should return the same URI for HF datasets - - -@pytest.fixture -def stub_hfapi(): - with patch("data_designer.engine.resources.seed_dataset_data_store.HfApi") as mock_api: - mock_instance = Mock() - mock_api.return_value = mock_instance - yield mock_instance - - -@pytest.fixture -def stub_hffs(): - with patch("data_designer.engine.resources.seed_dataset_data_store.HfFileSystem") as mock_fs: - mock_instance = Mock() - mock_fs.return_value = mock_instance - yield mock_instance - - -@pytest.fixture -def stub_remote_store(stub_hfapi, stub_hffs): - return 
HfHubSeedDatasetDataStore(endpoint="https://test.endpoint", token="test_token") - - -def test_hf_hub_seed_dataset_data_store_mocked_duckdb_connection_with_filesystem(stub_remote_store, stub_hffs): - mock_conn = Mock() - - with patch("data_designer.engine.resources.seed_dataset_data_store.duckdb") as mock_duckdb: - mock_duckdb.connect.return_value = mock_conn - - conn = stub_remote_store.create_duckdb_connection() - - mock_conn.register_filesystem.assert_called_once_with(stub_hffs) - assert conn == mock_conn - - -def test_hf_hub_seed_dataset_data_store_mocked_dataset_uri_generation(stub_remote_store): - file_id = "hf://datasets/test_namespace/test_dataset/test_file.parquet" - uri = stub_remote_store.get_dataset_uri(file_id) - - assert uri == file_id # Should return the same URI - - -def test_hf_hub_seed_dataset_data_store_mocked_load_dataset_file_success(stub_remote_store, stub_hfapi): - file_id = "hf://datasets/test_namespace/test_dataset/test_file.parquet" - stub_hfapi.repo_exists.return_value = True - stub_hfapi.file_exists.return_value = True - - with patch("data_designer.engine.resources.seed_dataset_data_store.tempfile") as mock_tempfile: - with patch("data_designer.engine.resources.seed_dataset_data_store.load_dataset") as mock_load_dataset: - mock_tempfile.TemporaryDirectory.return_value.__enter__.return_value = "/tmp/test" - - mock_hf_dataset = Mock() - mock_hf_dataset.to_pandas.return_value = pd.DataFrame({"a": [1, 2, 3]}) - mock_load_dataset.return_value = mock_hf_dataset - - result = stub_remote_store.load_dataset(file_id) - - stub_hfapi.repo_exists.assert_called_once_with("test_namespace/test_dataset", repo_type="dataset") - stub_hfapi.file_exists.assert_called_once_with( - "test_namespace/test_dataset", "test_file.parquet", repo_type="dataset" - ) - stub_hfapi.hf_hub_download.assert_called_once_with( - repo_id="test_namespace/test_dataset", - filename="test_file.parquet", - local_dir="/tmp/test", - repo_type="dataset", - ) - - assert isinstance(result, pd.DataFrame) - assert len(result) == 3 diff --git a/tests/engine/resources/test_seed_dataset_data_store.py b/tests/engine/resources/test_seed_dataset_data_store.py deleted file mode 100644 index bd3cf00f..00000000 --- a/tests/engine/resources/test_seed_dataset_data_store.py +++ /dev/null @@ -1,169 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from pathlib import Path -import tempfile -from unittest.mock import Mock, patch - -import pandas as pd -import pytest - -from data_designer.config.errors import InvalidFilePathError -from data_designer.engine.resources.seed_dataset_data_store import ( - _HF_DATASETS_PREFIX, - HfHubSeedDatasetDataStore, - LocalSeedDatasetDataStore, -) - -SEED_DATASET_DATA_STORE_MODULE = "data_designer.engine.resources.seed_dataset_data_store" - - -@pytest.fixture -def stub_sample_dataframe(): - return pd.DataFrame(data={"a": [1, 2, 3]}) - - -@pytest.fixture -def stub_hfapi(): - with patch(f"{SEED_DATASET_DATA_STORE_MODULE}.HfApi") as mock_api: - mock_instance = Mock() - mock_api.return_value = mock_instance - yield mock_instance - - -@pytest.fixture -def stub_remote_store(stub_hfapi): - return HfHubSeedDatasetDataStore(endpoint="https://test.endpoint", token="test_token") - - -@pytest.fixture -def stub_temp_base_dir(): - with tempfile.TemporaryDirectory() as tmpdir: - yield Path(tmpdir) - - -def test_local_seed_dataset_data_store_init(): - datastore = LocalSeedDatasetDataStore() - assert datastore.get_dataset_uri("test.csv") == "test.csv" - - -@pytest.mark.parametrize( - "filename,format_func", - [ - ("test.csv", lambda df, path: df.to_csv(path, index=False)), - ("test.parquet", lambda df, path: df.to_parquet(path, index=False)), - ("test.CSV", lambda df, path: df.to_csv(path, index=False)), # Case insensitive - ], -) -def test_local_load_dataset_supported_formats(filename, format_func, stub_sample_dataframe, stub_temp_base_dir): - format_func(stub_sample_dataframe, stub_temp_base_dir / filename) - - datastore = LocalSeedDatasetDataStore() - dataset = datastore.load_dataset(stub_temp_base_dir / filename) - pd.testing.assert_frame_equal(stub_sample_dataframe, dataset) - - -@pytest.mark.parametrize( - "test_case,filename,expected_error", - [ - ("unsupported_format", "test.txt", InvalidFilePathError), - ("file_not_found", "nonexistent.csv", InvalidFilePathError), - ], -) -def test_local_load_dataset_error_cases(test_case, filename, expected_error, stub_temp_base_dir): - datastore = LocalSeedDatasetDataStore() - - if test_case == "unsupported_format": - with open(stub_temp_base_dir / filename, "w") as f: - f.write("This is not a supported format") - - with pytest.raises(expected_error): - datastore.load_dataset(filename) - - -def test_hfhub_seed_dataset_data_store_init(stub_hfapi): - store = HfHubSeedDatasetDataStore(endpoint="https://custom.endpoint", token="custom_token") - assert store.hfapi == stub_hfapi - - -@pytest.mark.parametrize( - "error_type,repo_exists_return,expected_error", - [ - (FileNotFoundError, False, "Repo test_namespace/test_dataset does not exist"), - ( - FileNotFoundError, - lambda repo_id, repo_type: repo_type == "model", - "Repo test_namespace/test_dataset is a model repo, not a dataset repo", - ), - (FileNotFoundError, True, "File file.parquet does not exist in repo test_namespace/test_dataset"), - ], -) -def test_load_dataset_errors(stub_remote_store, stub_hfapi, error_type, repo_exists_return, expected_error): - file_id = f"{_HF_DATASETS_PREFIX}test_namespace/test_dataset/file.parquet" - - if callable(repo_exists_return): - stub_hfapi.repo_exists.side_effect = repo_exists_return - else: - stub_hfapi.repo_exists.return_value = repo_exists_return - - if repo_exists_return is True: - stub_hfapi.file_exists.return_value = False - - with pytest.raises(error_type, match=expected_error): - stub_remote_store.load_dataset(file_id) - - 
-@patch(f"{SEED_DATASET_DATA_STORE_MODULE}.load_dataset", autospec=True) -@patch(f"{SEED_DATASET_DATA_STORE_MODULE}.tempfile", autospec=True) -def test_load_dataset_file_success( - mock_tempfile, mock_load_dataset, stub_remote_store, stub_hfapi, stub_sample_dataframe -): - file_id = f"{_HF_DATASETS_PREFIX}test_namespace/test_dataset/file.parquet" - stub_hfapi.repo_exists.return_value = True - stub_hfapi.file_exists.return_value = True - - mock_temp_dir = "/tmp/test_dir" - mock_tempfile.TemporaryDirectory.return_value.__enter__.return_value = mock_temp_dir - - mock_hf_dataset = Mock() - mock_hf_dataset.to_pandas.return_value = stub_sample_dataframe - mock_load_dataset.return_value = mock_hf_dataset - - result = stub_remote_store.load_dataset(file_id) - - stub_hfapi.file_exists.assert_called_once_with("test_namespace/test_dataset", "file.parquet", repo_type="dataset") - stub_hfapi.hf_hub_download.assert_called_once_with( - repo_id="test_namespace/test_dataset", filename="file.parquet", local_dir=mock_temp_dir, repo_type="dataset" - ) - - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, stub_sample_dataframe) - - -@patch(f"{SEED_DATASET_DATA_STORE_MODULE}.load_dataset", autospec=True) -@patch(f"{SEED_DATASET_DATA_STORE_MODULE}.tempfile", autospec=True) -@patch(f"{SEED_DATASET_DATA_STORE_MODULE}.os.path.exists", autospec=True) -@pytest.mark.parametrize("dir_exists", [True, False]) -def test_load_dataset_directory_success( - mock_exists, mock_tempfile, mock_load_dataset, stub_remote_store, stub_hfapi, dir_exists, stub_sample_dataframe -): - dir_id = f"{_HF_DATASETS_PREFIX}test_namespace/test_dataset/directory" - stub_hfapi.repo_exists.return_value = True - mock_temp_dir = "/tmp/test_dir" - mock_tempfile.TemporaryDirectory.return_value.__enter__.return_value = mock_temp_dir - mock_exists.return_value = dir_exists - - mock_hf_dataset = Mock() - mock_hf_dataset.to_pandas.return_value = stub_sample_dataframe - mock_load_dataset.return_value = mock_hf_dataset - - result = stub_remote_store.load_dataset(dir_id) - - stub_hfapi.snapshot_download.assert_called_once_with( - repo_id="test_namespace/test_dataset", local_dir=mock_temp_dir, repo_type="dataset" - ) - - expected_path = f"{mock_temp_dir}/directory" if dir_exists else mock_temp_dir - mock_load_dataset.assert_called_once_with(path=expected_path) - assert isinstance(result, pd.DataFrame) - pd.testing.assert_frame_equal(result, stub_sample_dataframe) diff --git a/tests/engine/sampling_gen/test_jinja_utils.py b/tests/engine/sampling_gen/test_jinja_utils.py index 7acc348c..ef1a7817 100644 --- a/tests/engine/sampling_gen/test_jinja_utils.py +++ b/tests/engine/sampling_gen/test_jinja_utils.py @@ -76,7 +76,9 @@ def test_jinja_dataframe_select_index_scenarios(test_case, expr, df_data, mock_s assert result.tolist() == expected_result else: result = jdf.select_index(df) - if expected_result == "empty_index" or expected_result == "full_index": + if expected_result == "empty_index": + assert result.equals(df.index) + elif expected_result == "full_index": assert result.equals(df.index) diff --git a/tests/engine/test_configurable_task.py b/tests/engine/test_configurable_task.py index ffe0d8b1..1210b448 100644 --- a/tests/engine/test_configurable_task.py +++ b/tests/engine/test_configurable_task.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from typing import Type from unittest.mock import Mock import pandas as pd @@ -48,7 +49,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> type[TestConfig]: + def get_config_type(cls) -> Type[TestConfig]: return TestConfig @classmethod @@ -81,7 +82,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> type[TestConfig]: + def get_config_type(cls) -> Type[TestConfig]: return TestConfig @classmethod @@ -114,7 +115,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> type[TestConfig]: + def get_config_type(cls) -> Type[TestConfig]: return TestConfig @classmethod diff --git a/tests/engine/test_secret_resolver.py b/tests/engine/test_secret_resolver.py index aa3fd23f..ed20eb61 100644 --- a/tests/engine/test_secret_resolver.py +++ b/tests/engine/test_secret_resolver.py @@ -30,14 +30,14 @@ def test_secrets_file_resolution(stub_secrets_file: Path): assert resolver.resolve("FOO") == "foo123" -def test_not_found(stub_secrets_file: Path): +def test_secrets_file_key_not_found(stub_secrets_file: Path): resolver = SecretsFileResolver(stub_secrets_file) with pytest.raises(SecretResolutionError): resolver.resolve("QUUX") -def test_file_doesnt_exist(): +def test_secrets_file_doesnt_exist(): # the resolver will instantiate... resolver = SecretsFileResolver(Path("/this/will/not/exist.json")) @@ -74,3 +74,16 @@ def test_composite_resolver(monkeypatch, stub_secrets_file: Path): assert resolver.resolve("FOO") == "foo000" assert resolver.resolve("BAR") == "bar789" + + +def test_composite_resolver_error(stub_secrets_file: Path): + resolvers = [EnvironmentResolver(), SecretsFileResolver(stub_secrets_file)] + + resolver = CompositeResolver(resolvers) + + with pytest.raises(SecretResolutionError) as excinfo: + resolver.resolve("QUUX") + + # The composite error message aggregates the individual resolvers' error messages + assert "env var" in str(excinfo.value) + assert "secret" in str(excinfo.value) diff --git a/tests/engine/validators/test_local_callable.py b/tests/engine/validators/test_local_callable.py index 62ab71d4..29cab887 100644 --- a/tests/engine/validators/test_local_callable.py +++ b/tests/engine/validators/test_local_callable.py @@ -17,7 +17,8 @@ def test_validate_with_callback_validator(stub_data: list[dict]): def callback_fn(df: pd.DataFrame) -> pd.DataFrame: if df.iloc[0]["text"] == "Sample text": return pd.DataFrame([{"is_valid": True, "confidence": "0.98"}]) - return pd.DataFrame([{"is_valid": False, "confidence": "0.0"}]) + else: + return pd.DataFrame([{"is_valid": False, "confidence": "0.0"}]) validator = LocalCallableValidator( LocalCallableValidatorParams( diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py index fc1b4fdb..89f8388a 100644 --- a/tests/essentials/test_init.py +++ b/tests/essentials/test_init.py @@ -179,9 +179,9 @@ def test_conditional_imports_based_on_can_run_locally(): assert hasattr(essentials, "DataDesigner") assert hasattr(essentials, "LocalCallableValidatorParams") assert hasattr(essentials, "ModelProvider") - assert essentials.DataDesigner is not None - assert essentials.LocalCallableValidatorParams is not None - assert essentials.ModelProvider is not None + assert getattr(essentials, "DataDesigner") is not None + assert getattr(essentials, "LocalCallableValidatorParams") is not None + assert 
getattr(essentials, "ModelProvider") is not None assert "DataDesigner" in __all__ assert "LocalCallableValidatorParams" in __all__ assert "ModelProvider" in __all__