diff --git a/src/data_designer/config/analysis/__init__.py b/src/data_designer/config/analysis/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/config/analysis/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/config/analysis/column_profilers.py b/src/data_designer/config/analysis/column_profilers.py
index cbb292a0..0a236782 100644
--- a/src/data_designer/config/analysis/column_profilers.py
+++ b/src/data_designer/config/analysis/column_profilers.py
@@ -3,11 +3,12 @@
from abc import ABC
from enum import Enum
-from typing import TypeAlias
+from typing import Optional, Union
from pydantic import BaseModel, Field
from rich.panel import Panel
from rich.table import Column, Table
+from typing_extensions import TypeAlias
from ..base import ConfigBase
from ..utils.visualization import ColorPalette
@@ -37,20 +38,20 @@ def create_report_section(self) -> Panel:
class JudgeScoreProfilerConfig(ConfigBase):
model_alias: str
- summary_score_sample_size: int | None = Field(default=20, ge=1)
+ summary_score_sample_size: Optional[int] = Field(default=20, ge=1)
class JudgeScoreSample(BaseModel):
- score: int | str
+ score: Union[int, str]
reasoning: str
class JudgeScoreDistributions(BaseModel):
- scores: dict[str, list[int | str]]
+ scores: dict[str, list[Union[int, str]]]
reasoning: dict[str, list[str]]
distribution_types: dict[str, ColumnDistributionType]
- distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue]
- histograms: dict[str, CategoricalHistogramData | MissingValue]
+ distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
+ histograms: dict[str, Union[CategoricalHistogramData, MissingValue]]
class JudgeScoreSummary(BaseModel):
@@ -62,7 +63,7 @@ class JudgeScoreSummary(BaseModel):
class JudgeScoreProfilerResults(ColumnProfilerResults):
column_name: str
summaries: dict[str, JudgeScoreSummary]
- score_distributions: JudgeScoreDistributions | MissingValue
+ score_distributions: Union[JudgeScoreDistributions, MissingValue]
def create_report_section(self) -> Panel:
layout = Table.grid(Column(), expand=True, padding=(2, 0))
diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py
index 49d6530a..991e41b9 100644
--- a/src/data_designer/config/analysis/column_statistics.py
+++ b/src/data_designer/config/analysis/column_statistics.py
@@ -5,11 +5,11 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal, Optional, Union
from pandas import Series
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, TypeAlias
from ..columns import DataDesignerColumnType
from ..sampler_params import SamplerType
@@ -39,19 +39,19 @@ def create_report_row_data(self) -> dict[str, str]: ...
class GeneralColumnStatistics(BaseColumnStatistics):
column_name: str
- num_records: int | MissingValue
- num_null: int | MissingValue
- num_unique: int | MissingValue
+ num_records: Union[int, MissingValue]
+ num_null: Union[int, MissingValue]
+ num_unique: Union[int, MissingValue]
pyarrow_dtype: str
simple_dtype: str
column_type: Literal["general"] = "general"
@field_validator("num_null", "num_unique", "num_records", mode="before")
- def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
+ def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
@property
- def percent_null(self) -> float | MissingValue:
+ def percent_null(self) -> Union[float, MissingValue]:
return (
self.num_null
if self._is_missing_value(self.num_null)
@@ -59,7 +59,7 @@ def percent_null(self) -> float | MissingValue:
)
@property
- def percent_unique(self) -> float | MissingValue:
+ def percent_unique(self) -> Union[float, MissingValue]:
return (
self.num_unique
if self._is_missing_value(self.num_unique)
@@ -78,17 +78,17 @@ def _general_display_row(self) -> dict[str, str]:
def create_report_row_data(self) -> dict[str, str]:
return self._general_display_row
- def _is_missing_value(self, v: float | int | MissingValue) -> bool:
+ def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool:
return v in set(MissingValue)
class LLMTextColumnStatistics(GeneralColumnStatistics):
- completion_tokens_mean: float | MissingValue
- completion_tokens_median: float | MissingValue
- completion_tokens_stddev: float | MissingValue
- prompt_tokens_mean: float | MissingValue
- prompt_tokens_median: float | MissingValue
- prompt_tokens_stddev: float | MissingValue
+ completion_tokens_mean: Union[float, MissingValue]
+ completion_tokens_median: Union[float, MissingValue]
+ completion_tokens_stddev: Union[float, MissingValue]
+ prompt_tokens_mean: Union[float, MissingValue]
+ prompt_tokens_median: Union[float, MissingValue]
+ prompt_tokens_stddev: Union[float, MissingValue]
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value
@field_validator(
@@ -100,7 +100,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
"prompt_tokens_stddev",
mode="before",
)
- def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue:
+ def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float)
def create_report_row_data(self) -> dict[str, Any]:
@@ -136,7 +136,7 @@ class LLMJudgedColumnStatistics(LLMTextColumnStatistics):
class SamplerColumnStatistics(GeneralColumnStatistics):
sampler_type: SamplerType
distribution_type: ColumnDistributionType
- distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
+ distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value
def create_report_row_data(self) -> dict[str, str]:
@@ -148,7 +148,7 @@ def create_report_row_data(self) -> dict[str, str]:
class SeedDatasetColumnStatistics(GeneralColumnStatistics):
distribution_type: ColumnDistributionType
- distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
+ distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
column_type: Literal[DataDesignerColumnType.SEED_DATASET.value] = DataDesignerColumnType.SEED_DATASET.value
def create_report_row_data(self) -> dict[str, str]:
@@ -160,15 +160,15 @@ class ExpressionColumnStatistics(GeneralColumnStatistics):
class ValidationColumnStatistics(GeneralColumnStatistics):
- num_valid_records: int | MissingValue
+ num_valid_records: Union[int, MissingValue]
column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value
@field_validator("num_valid_records", mode="before")
- def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
+ def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
@property
- def percent_valid(self) -> float | MissingValue:
+ def percent_valid(self) -> Union[float, MissingValue]:
return (
self.num_valid_records
if self._is_missing_value(self.num_valid_records)
@@ -181,7 +181,7 @@ def create_report_row_data(self) -> dict[str, str]:
class CategoricalHistogramData(BaseModel):
- categories: list[float | int | str]
+ categories: list[Union[float, int, str]]
counts: list[int]
@model_validator(mode="after")
@@ -198,12 +198,12 @@ def from_series(cls, series: Series) -> Self:
class CategoricalDistribution(BaseModel):
- most_common_value: str | int
- least_common_value: str | int
+ most_common_value: Union[str, int]
+ least_common_value: Union[str, int]
histogram: CategoricalHistogramData
@field_validator("most_common_value", "least_common_value", mode="before")
- def ensure_python_types(cls, v: str | int) -> str | int:
+ def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]:
return str(v) if not is_int(v) else prepare_number_for_reporting(v, int)
@classmethod
@@ -217,14 +217,14 @@ def from_series(cls, series: Series) -> Self:
class NumericalDistribution(BaseModel):
- min: float | int
- max: float | int
+ min: Union[float, int]
+ max: Union[float, int]
mean: float
stddev: float
median: float
@field_validator("min", "max", "mean", "stddev", "median", mode="before")
- def ensure_python_types(cls, v: float | int) -> float | int:
+ def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]:
return prepare_number_for_reporting(v, int if is_int(v) else float)
@classmethod
@@ -239,14 +239,16 @@ def from_series(cls, series: Series) -> Self:
ColumnStatisticsT: TypeAlias = Annotated[
- GeneralColumnStatistics
- | LLMTextColumnStatistics
- | LLMCodeColumnStatistics
- | LLMStructuredColumnStatistics
- | LLMJudgedColumnStatistics
- | SamplerColumnStatistics
- | SeedDatasetColumnStatistics
- | ValidationColumnStatistics
- | ExpressionColumnStatistics,
+ Union[
+ GeneralColumnStatistics,
+ LLMTextColumnStatistics,
+ LLMCodeColumnStatistics,
+ LLMStructuredColumnStatistics,
+ LLMJudgedColumnStatistics,
+ SamplerColumnStatistics,
+ SeedDatasetColumnStatistics,
+ ValidationColumnStatistics,
+ ExpressionColumnStatistics,
+ ],
Field(discriminator="column_type"),
]
diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py
index 07647768..f9e0e168 100644
--- a/src/data_designer/config/analysis/dataset_profiler.py
+++ b/src/data_designer/config/analysis/dataset_profiler.py
@@ -3,6 +3,7 @@
from functools import cached_property
from pathlib import Path
+from typing import Optional, Union
from pydantic import BaseModel, Field, field_validator
@@ -18,8 +19,8 @@ class DatasetProfilerResults(BaseModel):
num_records: int
target_num_records: int
column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1)
- side_effect_column_names: list[str] | None = None
- column_profiles: list[ColumnProfilerResultsT] | None = None
+ side_effect_column_names: Optional[list[str]] = None
+ column_profiles: Optional[list[ColumnProfilerResultsT]] = None
@field_validator("num_records", "target_num_records", mode="before")
def ensure_python_integers(cls, v: int) -> int:
@@ -42,8 +43,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) ->
def to_report(
self,
- save_path: str | Path | None = None,
- include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
+ save_path: Optional[Union[str, Path]] = None,
+ include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
) -> None:
"""Generate and print an analysis report based on the dataset profiling results.
diff --git a/src/data_designer/config/analysis/utils/__init__.py b/src/data_designer/config/analysis/utils/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/config/analysis/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py
index 35cd7687..ca62a0cd 100644
--- a/src/data_designer/config/analysis/utils/reporting.py
+++ b/src/data_designer/config/analysis/utils/reporting.py
@@ -5,7 +5,7 @@
from enum import Enum
from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Union
from rich.align import Align
from rich.console import Console, Group
@@ -49,8 +49,8 @@ class ReportSection(str, Enum):
def generate_analysis_report(
analysis: DatasetProfilerResults,
- save_path: str | Path | None = None,
- include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
+ save_path: Optional[Union[str, Path]] = None,
+ include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
) -> None:
"""Generate an analysis report for dataset profiling results.
@@ -166,7 +166,7 @@ def create_judge_score_summary_table(
layout = Table.grid(Column(), Column(), expand=True, padding=(0, 2))
histogram_table = create_rich_histogram_table(
- {str(s): c for s, c in zip(histogram.categories, histogram.counts, strict=False)},
+ {str(s): c for s, c in zip(histogram.categories, histogram.counts)},
("score", "count"),
name_style=HIST_NAME_STYLE,
value_style=HIST_VALUE_STYLE,
diff --git a/src/data_designer/config/base.py b/src/data_designer/config/base.py
index 4e4f0f66..64d4d6de 100644
--- a/src/data_designer/config/base.py
+++ b/src/data_designer/config/base.py
@@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union
import pandas as pd
from pydantic import BaseModel, ConfigDict
@@ -14,9 +14,9 @@
from .utils.io_helpers import serialize_data
if TYPE_CHECKING:
- from ..client.results.preview import PreviewResults
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
+ from .preview_results import PreviewResults
DEFAULT_NUM_RECORDS = 10
@@ -66,7 +66,7 @@ def to_dict(self) -> dict[str, Any]:
"""
return self.model_dump(mode="json")
- def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
+ def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
"""Convert the configuration to a YAML string or file.
Args:
@@ -84,7 +84,7 @@ def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **k
with open(path, "w") as f:
f.write(yaml_str)
- def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
+ def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
"""Convert the configuration to a JSON string or file.
Args:
diff --git a/src/data_designer/config/columns.py b/src/data_designer/config/columns.py
index 5eada374..c2449499 100644
--- a/src/data_designer/config/columns.py
+++ b/src/data_designer/config/columns.py
@@ -3,10 +3,10 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Literal, TypeAlias
+from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, Field, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, TypeAlias
from .base import ConfigBase
from .errors import InvalidColumnTypeError, InvalidConfigError
@@ -79,7 +79,7 @@ class SamplerColumnConfig(SingleColumnConfig):
sampler_type: SamplerType
params: SamplerParamsT
conditional_params: dict[str, SamplerParamsT] = {}
- convert_to: str | None = None
+ convert_to: Optional[str] = None
@property
def column_type(self) -> DataDesignerColumnType:
@@ -89,8 +89,8 @@ def column_type(self) -> DataDesignerColumnType:
class LLMTextColumnConfig(SingleColumnConfig):
prompt: str
model_alias: str
- system_prompt: str | None = None
- multi_modal_context: list[ImageContext] | None = None
+ system_prompt: Optional[str] = None
+ multi_modal_context: Optional[list[ImageContext]] = None
@property
def column_type(self) -> DataDesignerColumnType:
@@ -124,7 +124,7 @@ def column_type(self) -> DataDesignerColumnType:
class LLMStructuredColumnConfig(LLMTextColumnConfig):
- output_format: dict | type[BaseModel]
+ output_format: Union[dict, Type[BaseModel]]
@property
def column_type(self) -> DataDesignerColumnType:
@@ -140,7 +140,7 @@ def validate_output_format(self) -> Self:
class Score(ConfigBase):
name: str = Field(..., description="A clear name for this score.")
description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
- options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.")
+ options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")
class LLMJudgeColumnConfig(LLMTextColumnConfig):
@@ -208,16 +208,16 @@ def column_type(self) -> DataDesignerColumnType:
}
-ColumnConfigT: TypeAlias = (
- ExpressionColumnConfig
- | LLMCodeColumnConfig
- | LLMJudgeColumnConfig
- | LLMStructuredColumnConfig
- | LLMTextColumnConfig
- | SamplerColumnConfig
- | SeedDatasetColumnConfig
- | ValidationColumnConfig
-)
+ColumnConfigT: TypeAlias = Union[
+ ExpressionColumnConfig,
+ LLMCodeColumnConfig,
+ LLMJudgeColumnConfig,
+ LLMStructuredColumnConfig,
+ LLMTextColumnConfig,
+ SamplerColumnConfig,
+ SeedDatasetColumnConfig,
+ ValidationColumnConfig,
+]
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
@@ -234,19 +234,19 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
if column_type == DataDesignerColumnType.LLM_TEXT:
return LLMTextColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.LLM_CODE:
+ elif column_type == DataDesignerColumnType.LLM_CODE:
return LLMCodeColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.LLM_STRUCTURED:
+ elif column_type == DataDesignerColumnType.LLM_STRUCTURED:
return LLMStructuredColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.LLM_JUDGE:
+ elif column_type == DataDesignerColumnType.LLM_JUDGE:
return LLMJudgeColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.VALIDATION:
+ elif column_type == DataDesignerColumnType.VALIDATION:
return ValidationColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.EXPRESSION:
+ elif column_type == DataDesignerColumnType.EXPRESSION:
return ExpressionColumnConfig(name=name, **kwargs)
- if column_type == DataDesignerColumnType.SAMPLER:
+ elif column_type == DataDesignerColumnType.SAMPLER:
return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
- if column_type == DataDesignerColumnType.SEED_DATASET:
+ elif column_type == DataDesignerColumnType.SEED_DATASET:
return SeedDatasetColumnConfig(name=name, **kwargs)
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py
index 78b2195d..a33b3658 100644
--- a/src/data_designer/config/config_builder.py
+++ b/src/data_designer/config/config_builder.py
@@ -6,6 +6,7 @@
import json
import logging
from pathlib import Path
+from typing import Optional, Union
from pygments import highlight
from pygments.formatters import HtmlFormatter
@@ -16,9 +17,11 @@
from .base import ExportableConfigBase
from .columns import ColumnConfigT, DataDesignerColumnType, SeedDatasetColumnConfig, get_column_config_from_kwargs
from .data_designer_config import DataDesignerConfig
+from .dataset_builders import BuildStage
from .datastore import DatastoreSettings, fetch_seed_dataset_column_names
from .errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
from .models import ModelConfig, load_model_configs
+from .processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs
from .sampler_constraints import (
ColumnConstraintT,
ColumnInequalityConstraint,
@@ -60,7 +63,7 @@ class BuilderConfig(ExportableConfigBase):
"""
data_designer: DataDesignerConfig
- datastore_settings: DatastoreSettings | None
+ datastore_settings: Optional[DatastoreSettings]
class DataDesignerConfigBuilder:
@@ -70,7 +73,7 @@ class DataDesignerConfigBuilder:
"""
@classmethod
- def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
+ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self:
"""Create a DataDesignerConfigBuilder from an existing configuration.
Args:
@@ -117,7 +120,7 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
return builder
- def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None):
+ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None):
"""Initialize a new DataDesignerConfigBuilder instance.
Args:
@@ -128,11 +131,12 @@ def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None):
"""
self._column_configs = {}
self._model_configs = load_model_configs(model_configs)
- self._seed_config: SeedConfig | None = None
+ self._processor_configs: list[ProcessorConfig] = []
+ self._seed_config: Optional[SeedConfig] = None
self._constraints: list[ColumnConstraintT] = []
self._profilers: list[ColumnProfilerConfigT] = []
self._info = DataDesignerInfo()
- self._datastore_settings: DatastoreSettings | None = None
+ self._datastore_settings: Optional[DatastoreSettings] = None
@property
def model_configs(self) -> list[ModelConfig]:
@@ -193,10 +197,10 @@ def delete_model_config(self, alias: str) -> Self:
def add_column(
self,
- column_config: ColumnConfigT | None = None,
+ column_config: Optional[ColumnConfigT] = None,
*,
- name: str | None = None,
- column_type: DataDesignerColumnType | None = None,
+ name: Optional[str] = None,
+ column_type: Optional[DataDesignerColumnType] = None,
**kwargs,
) -> Self:
"""Add a Data Designer column configuration to the current Data Designer configuration.
@@ -233,9 +237,9 @@ def add_column(
def add_constraint(
self,
- constraint: ColumnConstraintT | None = None,
+ constraint: Optional[ColumnConstraintT] = None,
*,
- constraint_type: ConstraintType | None = None,
+ constraint_type: Optional[ConstraintType] = None,
**kwargs,
) -> Self:
"""Add a constraint to the current Data Designer configuration.
@@ -283,6 +287,50 @@ def add_constraint(
self._constraints.append(constraint)
return self
+ def add_processor(
+ self,
+ processor_config: Optional[ProcessorConfig] = None,
+ *,
+ processor_type: Optional[ProcessorType] = None,
+ **kwargs,
+ ) -> Self:
+ """Add a processor to the current Data Designer configuration.
+
+ You can either provide a processor config object directly, or provide a processor type and
+ additional keyword arguments to construct the processor config object.
+
+ Args:
+ processor_config: The processor configuration object to add.
+ processor_type: The type of processor to add.
+ **kwargs: Additional keyword arguments to pass to the processor constructor.
+
+ Returns:
+ The current Data Designer config builder instance.
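+
+        Example (illustrative column name):
+            >>> builder.add_processor(
+            ...     processor_type=ProcessorType.DROP_COLUMNS,
+            ...     build_stage=BuildStage.POST_BATCH,
+            ...     column_names=["scratch_column"],
+            ... )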
+ """
+ if processor_config is None:
+ if processor_type is None:
+ raise BuilderConfigurationError(
+ "🛑 You must provide either a 'processor_config' object or 'processor_type' "
+ "with additional keyword arguments."
+ )
+ processor_config = get_processor_config_from_kwargs(processor_type=processor_type, **kwargs)
+
+        # Mark columns dropped by this processor in their column configs; checks elsewhere fail if a dropped column is not flagged with drop=True.
+ if processor_config.processor_type == ProcessorType.DROP_COLUMNS:
+ for column in processor_config.column_names:
+ if column in self._column_configs:
+ self._column_configs[column].drop = True
+
+ self._processor_configs.append(processor_config)
+ return self
+
def add_profiler(self, profiler_config: ColumnProfilerConfigT) -> Self:
"""Add a profiler to the current Data Designer configuration.
@@ -331,6 +372,7 @@ def build(self, *, skip_validation: bool = False, raise_exceptions: bool = False
columns=list(self._column_configs.values()),
constraints=self._constraints or None,
profilers=self._profilers or None,
+ processors=self._processor_configs or None,
)
def delete_constraints(self, target_column: str) -> Self:
@@ -427,7 +469,18 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
return [c for c in self._column_configs.values() if c.column_type != column_type]
- def get_seed_config(self) -> SeedConfig | None:
+    def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfig]]:
+        """Get processor configuration objects grouped by dataset builder stage.
+
+        Returns:
+            A dictionary mapping each build stage to the processor configuration objects that run in it.
+        """
+        configs_by_stage: dict[BuildStage, list[ProcessorConfig]] = {}
+        for processor_config in self._processor_configs:
+            configs_by_stage.setdefault(processor_config.build_stage, []).append(processor_config)
+        return configs_by_stage
+
+ def get_seed_config(self) -> Optional[SeedConfig]:
"""Get the seed config for the current Data Designer configuration.
Returns:
@@ -435,7 +485,7 @@ def get_seed_config(self) -> SeedConfig | None:
"""
return self._seed_config
- def get_seed_datastore_settings(self) -> DatastoreSettings | None:
+ def get_seed_datastore_settings(self) -> Optional[DatastoreSettings]:
"""Get most recent datastore settings for the current Data Designer configuration.
Returns:
@@ -454,7 +504,7 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:
"""
return len(self.get_columns_of_type(column_type))
- def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self:
+ def set_seed_datastore_settings(self, datastore_settings: Optional[DatastoreSettings]) -> Self:
"""Set the datastore settings for the seed dataset.
Args:
@@ -477,7 +527,9 @@ def validate(self, *, raise_exceptions: bool = False) -> Self:
"""
violations = validate_data_designer_config(
- columns=list(self._column_configs.values()), allowed_references=self.allowed_references
+ columns=list(self._column_configs.values()),
+ processor_configs=self._processor_configs,
+ allowed_references=self.allowed_references,
)
rich_print_violations(violations)
if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
@@ -516,7 +568,7 @@ def with_seed_dataset(
self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name)
return self
- def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None:
+ def write_config(self, path: Union[str, Path], indent: Optional[int] = 2, **kwargs) -> None:
"""Write the current configuration to a file.
Args:
diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py
index 8579ecd9..ba717fd8 100644
--- a/src/data_designer/config/data_designer_config.py
+++ b/src/data_designer/config/data_designer_config.py
@@ -3,12 +3,15 @@
from __future__ import annotations
+from typing import Optional
+
from pydantic import Field
from .analysis.column_profilers import ColumnProfilerConfigT
from .base import ExportableConfigBase
from .columns import ColumnConfigT
from .models import ModelConfig
+from .processors import ProcessorConfig
from .sampler_constraints import ColumnConstraintT
from .seed import SeedConfig
@@ -30,7 +33,8 @@ class DataDesignerConfig(ExportableConfigBase):
"""
columns: list[ColumnConfigT] = Field(min_length=1)
- model_configs: list[ModelConfig] | None = None
- seed_config: SeedConfig | None = None
- constraints: list[ColumnConstraintT] | None = None
- profilers: list[ColumnProfilerConfigT] | None = None
+ model_configs: Optional[list[ModelConfig]] = None
+ seed_config: Optional[SeedConfig] = None
+ constraints: Optional[list[ColumnConstraintT]] = None
+ profilers: Optional[list[ColumnProfilerConfigT]] = None
+ processors: Optional[list[ProcessorConfig]] = None
diff --git a/src/data_designer/config/dataset_builders.py b/src/data_designer/config/dataset_builders.py
new file mode 100644
index 00000000..aa18df7f
--- /dev/null
+++ b/src/data_designer/config/dataset_builders.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import Enum
+
+
+class BuildStage(str, Enum):
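+    """Stages of the dataset build at which a processor can run.
+
+    Note: processor configs currently accept only POST_BATCH (see SUPPORTED_STAGES in processors.py).
+    """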
+ PRE_BATCH = "pre_batch"
+ POST_BATCH = "post_batch"
+ PRE_GENERATION = "pre_generation"
+ POST_GENERATION = "post_generation"
diff --git a/src/data_designer/config/datastore.py b/src/data_designer/config/datastore.py
index fbf17eed..d1ea0c07 100644
--- a/src/data_designer/config/datastore.py
+++ b/src/data_designer/config/datastore.py
@@ -5,7 +5,7 @@
import logging
from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Union
from huggingface_hub import HfApi, HfFileSystem
import pandas as pd
@@ -28,17 +28,18 @@ class DatastoreSettings(BaseModel):
...,
description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.",
)
- token: str | None = Field(default=None, description="If needed, token to use for authentication.")
+ token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
-def get_file_column_names(file_path: str | Path, file_type: str) -> list[str]:
+def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
"""Extract column names based on file type."""
if file_type == "parquet":
try:
schema = pq.read_schema(file_path)
if hasattr(schema, "names"):
return schema.names
- return [field.name for field in schema]
+ else:
+ return [field.name for field in schema]
except Exception as e:
logger.warning(f"Failed to process parquet file {file_path}: {e}")
return []
@@ -70,13 +71,16 @@ def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | No
raise InvalidConfigError("🛑 Datastore settings are required in order to upload datasets to the datastore.")
if isinstance(datastore_settings, DatastoreSettings):
return datastore_settings
- if isinstance(datastore_settings, dict):
+ elif isinstance(datastore_settings, dict):
return DatastoreSettings.model_validate(datastore_settings)
- raise InvalidConfigError("🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary.")
+ else:
+ raise InvalidConfigError(
+ "🛑 Invalid datastore settings format. Must be DatastoreSettings object or dictionary."
+ )
def upload_to_hf_hub(
- dataset_path: str | Path,
+ dataset_path: Union[str, Path],
filename: str,
repo_id: str,
datastore_settings: DatastoreSettings,
@@ -105,7 +109,7 @@ def upload_to_hf_hub(
def _fetch_seed_dataset_column_names_from_datastore(
repo_id: str,
filename: str,
- datastore_settings: DatastoreSettings | dict | None = None,
+ datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
) -> list[str]:
file_type = filename.split(".")[-1]
if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
@@ -123,7 +127,7 @@ def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -
return get_file_column_names(dataset_path, dataset_path.suffix.lower()[1:])
-def _validate_dataset_path(dataset_path: str | Path) -> Path:
+def _validate_dataset_path(dataset_path: Union[str, Path]) -> Path:
if not Path(dataset_path).is_file():
raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py
index b6bf26d6..4fedb0c9 100644
--- a/src/data_designer/config/models.py
+++ b/src/data_designer/config/models.py
@@ -4,11 +4,11 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
-from typing import Any, Generic, TypeAlias, TypeVar
+from typing import Any, Generic, Optional, TypeVar, Union
import numpy as np
from pydantic import BaseModel, Field, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, TypeAlias
from .base import ConfigBase
from .errors import InvalidConfigError
@@ -49,7 +49,7 @@ def get_context(self, record: dict) -> dict[str, Any]: ...
class ImageContext(ModalityContext):
modality: Modality = Modality.IMAGE
- image_format: ImageFormat | None = None
+ image_format: Optional[ImageFormat] = None
def get_context(self, record: dict) -> dict[str, Any]:
context = dict(type="image_url")
@@ -82,8 +82,8 @@ def sample(self) -> float: ...
class ManualDistributionParams(ConfigBase):
     values: list[float] = Field(min_length=1)
-    weights: list[float] | None = None
+    weights: Optional[list[float]] = None
@model_validator(mode="after")
def _normalize_weights(self) -> Self:
@@ -99,7 +99,7 @@ def _validate_equal_lengths(self) -> Self:
class ManualDistribution(Distribution[ManualDistributionParams]):
- distribution_type: DistributionType | None = "manual"
+ distribution_type: Optional[DistributionType] = "manual"
params: ManualDistributionParams
def sample(self) -> float:
@@ -118,26 +118,26 @@ def _validate_low_lt_high(self) -> Self:
class UniformDistribution(Distribution[UniformDistributionParams]):
- distribution_type: DistributionType | None = "uniform"
+ distribution_type: Optional[DistributionType] = "uniform"
params: UniformDistributionParams
def sample(self) -> float:
return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
-DistributionT: TypeAlias = UniformDistribution | ManualDistribution
+DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
class InferenceParameters(ConfigBase):
- temperature: float | DistributionT | None = None
- top_p: float | DistributionT | None = None
- max_tokens: int | None = Field(default=None, ge=1)
+ temperature: Optional[Union[float, DistributionT]] = None
+ top_p: Optional[Union[float, DistributionT]] = None
+ max_tokens: Optional[int] = Field(default=None, ge=1)
max_parallel_requests: int = Field(default=4, ge=1)
- timeout: int | None = Field(default=None, ge=1)
- extra_body: dict[str, Any] | None = None
+ timeout: Optional[int] = Field(default=None, ge=1)
+ extra_body: Optional[dict[str, Any]] = None
@property
- def generate_kwargs(self) -> dict[str, float | int]:
+ def generate_kwargs(self) -> dict[str, Union[float, int]]:
result = {}
if self.temperature is not None:
result["temperature"] = (
@@ -173,7 +173,7 @@ def _validate_top_p(self) -> Self:
def _run_validation(
self,
- value: float | DistributionT | None,
+        value: Optional[Union[float, DistributionT]],
param_name: str,
min_value: float,
max_value: float,
@@ -201,10 +201,10 @@ class ModelConfig(ConfigBase):
alias: str
model: str
inference_parameters: InferenceParameters
- provider: str | None = None
+ provider: Optional[str] = None
-def load_model_configs(model_configs: list[ModelConfig] | str | Path | None) -> list[ModelConfig]:
+def load_model_configs(model_configs: Optional[Union[list[ModelConfig], str, Path]]) -> list[ModelConfig]:
if model_configs is None:
return []
if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
diff --git a/src/data_designer/config/preview_results.py b/src/data_designer/config/preview_results.py
index ffe7ba20..41313c52 100644
--- a/src/data_designer/config/preview_results.py
+++ b/src/data_designer/config/preview_results.py
@@ -3,6 +3,8 @@
from __future__ import annotations
+from typing import Optional
+
import pandas as pd
from .analysis.dataset_profiler import DatasetProfilerResults
@@ -15,8 +17,8 @@ def __init__(
self,
*,
config_builder: DataDesignerConfigBuilder,
- dataset: pd.DataFrame | None = None,
- analysis: DatasetProfilerResults | None = None,
+ dataset: Optional[pd.DataFrame] = None,
+ analysis: Optional[DatasetProfilerResults] = None,
):
"""Creates a new instance with results from a Data Designer preview run.
diff --git a/src/data_designer/config/processors.py b/src/data_designer/config/processors.py
new file mode 100644
index 00000000..163956a3
--- /dev/null
+++ b/src/data_designer/config/processors.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC
+from enum import Enum
+from typing import Literal
+
+from pydantic import Field, field_validator
+
+from .base import ConfigBase
+from .dataset_builders import BuildStage
+from .errors import InvalidConfigError
+
+SUPPORTED_STAGES = [BuildStage.POST_BATCH]
+
+
+class ProcessorType(str, Enum):
+ DROP_COLUMNS = "drop_columns"
+
+
+class ProcessorConfig(ConfigBase, ABC):
+    processor_type: ProcessorType
+ build_stage: BuildStage = Field(
+ ..., description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}"
+ )
+
+ @field_validator("build_stage")
+ def validate_build_stage(cls, v: BuildStage) -> BuildStage:
+ if v not in SUPPORTED_STAGES:
+ raise ValueError(
+ f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}"
+ )
+ return v
+
+
+def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs) -> ProcessorConfig:
+    if processor_type == ProcessorType.DROP_COLUMNS:
+        return DropColumnsProcessorConfig(**kwargs)
+    raise InvalidConfigError(f"🛑 {processor_type} is not a valid processor type.")
+
+
+class DropColumnsProcessorConfig(ProcessorConfig):
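+    """Drop the listed columns from the generated dataset at the configured build stage.
+
+    Example (illustrative column name):
+        DropColumnsProcessorConfig(build_stage=BuildStage.POST_BATCH, column_names=["scratch_col"])
+    """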
+ column_names: list[str]
+ processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
diff --git a/src/data_designer/config/sampler_constraints.py b/src/data_designer/config/sampler_constraints.py
index d5036616..9dea2bae 100644
--- a/src/data_designer/config/sampler_constraints.py
+++ b/src/data_designer/config/sampler_constraints.py
@@ -3,7 +3,9 @@
from abc import ABC, abstractmethod
from enum import Enum
-from typing import TypeAlias
+from typing import Union
+
+from typing_extensions import TypeAlias
from .base import ConfigBase
@@ -46,4 +48,4 @@ def constraint_type(self) -> ConstraintType:
return ConstraintType.COLUMN_INEQUALITY
-ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
+ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint]
diff --git a/src/data_designer/config/sampler_params.py b/src/data_designer/config/sampler_params.py
index 6f3abf9e..03b278e2 100644
--- a/src/data_designer/config/sampler_params.py
+++ b/src/data_designer/config/sampler_params.py
@@ -2,11 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
-from typing import Literal, TypeAlias
+from typing import Literal, Optional, Union
import pandas as pd
from pydantic import Field, field_validator, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, TypeAlias
from .base import ConfigBase
from .utils.constants import (
@@ -41,12 +41,12 @@ class SamplerType(str, Enum):
class CategorySamplerParams(ConfigBase):
- values: list[str | int | float] = Field(
+ values: list[Union[str, int, float]] = Field(
...,
min_length=1,
description="List of possible categorical values that can be sampled from.",
)
- weights: list[float] | None = Field(
+ weights: Optional[list[float]] = Field(
default=None,
description=(
"List of unnormalized probability weights to assigned to each value, in order. "
@@ -87,7 +87,7 @@ def _validate_param_is_datetime(cls, value: str) -> str:
class SubcategorySamplerParams(ConfigBase):
category: str = Field(..., description="Name of parent category to this subcategory.")
- values: dict[str, list[str | int | float]] = Field(
+ values: dict[str, list[Union[str, int, float]]] = Field(
...,
description="Mapping from each value of parent category to a list of subcategory values.",
)
@@ -127,7 +127,7 @@ def _validate_min_less_than_max(self) -> Self:
class UUIDSamplerParams(ConfigBase):
- prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.")
+ prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.")
short_form: bool = Field(
default=False,
description="If true, all UUIDs sampled will be truncated at 8 characters.",
@@ -153,7 +153,7 @@ class ScipySamplerParams(ConfigBase):
...,
description="Parameters of the scipy.stats distribution given in `dist_name`.",
)
- decimal_places: int | None = Field(
+ decimal_places: Optional[int] = Field(
default=None, description="Number of decimal places to round the sampled values to."
)
@@ -191,7 +191,7 @@ class BernoulliMixtureSamplerParams(ConfigBase):
class GaussianSamplerParams(ConfigBase):
mean: float = Field(..., description="Mean of the Gaussian distribution")
stddev: float = Field(..., description="Standard deviation of the Gaussian distribution")
- decimal_places: int | None = Field(
+ decimal_places: Optional[int] = Field(
default=None, description="Number of decimal places to round the sampled values to."
)
@@ -203,7 +203,7 @@ class PoissonSamplerParams(ConfigBase):
class UniformSamplerParams(ConfigBase):
low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.")
high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.")
- decimal_places: int | None = Field(
+ decimal_places: Optional[int] = Field(
default=None, description="Number of decimal places to round the sampled values to."
)
@@ -223,11 +223,11 @@ class PersonSamplerParams(ConfigBase):
"that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..."
),
)
- sex: SexT | None = Field(
+ sex: Optional[SexT] = Field(
default=None,
description="If specified, then only synthetic people of the specified sex will be sampled.",
)
- city: str | list[str] | None = Field(
+ city: Optional[Union[str, list[str]]] = Field(
default=None,
description="If specified, then only synthetic people from these cities will be sampled.",
)
@@ -238,7 +238,7 @@ class PersonSamplerParams(ConfigBase):
max_length=2,
)
- state: str | list[str] | None = Field(
+ state: Optional[Union[str, list[str]]] = Field(
default=None,
description=(
"Only supported for 'en_US' locale. If specified, then only synthetic people "
@@ -265,7 +265,8 @@ def generator_kwargs(self) -> list[str]:
def people_gen_key(self) -> str:
if self.locale in LOCALES_WITH_MANAGED_DATASETS and self.sample_dataset_when_available:
return f"{self.locale}_with_personas" if self.with_synthetic_personas else self.locale
- return f"{self.locale}_faker"
+ else:
+ return f"{self.locale}_faker"
@field_validator("age_range")
@classmethod
@@ -321,21 +322,21 @@ def _validate_with_synthetic_personas(self) -> Self:
return self
-SamplerParamsT: TypeAlias = (
- SubcategorySamplerParams
- | CategorySamplerParams
- | DatetimeSamplerParams
- | PersonSamplerParams
- | TimeDeltaSamplerParams
- | UUIDSamplerParams
- | BernoulliSamplerParams
- | BernoulliMixtureSamplerParams
- | BinomialSamplerParams
- | GaussianSamplerParams
- | PoissonSamplerParams
- | UniformSamplerParams
- | ScipySamplerParams
-)
+SamplerParamsT: TypeAlias = Union[
+ SubcategorySamplerParams,
+ CategorySamplerParams,
+ DatetimeSamplerParams,
+ PersonSamplerParams,
+ TimeDeltaSamplerParams,
+ UUIDSamplerParams,
+ BernoulliSamplerParams,
+ BernoulliMixtureSamplerParams,
+ BinomialSamplerParams,
+ GaussianSamplerParams,
+ PoissonSamplerParams,
+ UniformSamplerParams,
+ ScipySamplerParams,
+]
def is_numerical_sampler_type(sampler_type: SamplerType) -> bool:
diff --git a/src/data_designer/config/utils/__init__.py b/src/data_designer/config/utils/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/config/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/config/utils/code_lang.py b/src/data_designer/config/utils/code_lang.py
index 4f1af4c8..c0621d36 100644
--- a/src/data_designer/config/utils/code_lang.py
+++ b/src/data_designer/config/utils/code_lang.py
@@ -4,6 +4,7 @@
from __future__ import annotations
from enum import Enum
+from typing import Optional, Union
class CodeLang(str, Enum):
@@ -25,17 +26,17 @@ class CodeLang(str, Enum):
SQL_ANSI = "sql:ansi"
@staticmethod
- def parse(value: str | CodeLang) -> tuple[str, str | None]:
+    def parse(value: Union[str, CodeLang]) -> tuple[str, Optional[str]]:
value = value.value if isinstance(value, CodeLang) else value
split_vals = value.split(":")
return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)
@staticmethod
- def parse_lang(value: str | CodeLang) -> str:
+ def parse_lang(value: Union[str, CodeLang]) -> str:
return CodeLang.parse(value)[0]
@staticmethod
- def parse_dialect(value: str | CodeLang) -> str | None:
+    def parse_dialect(value: Union[str, CodeLang]) -> Optional[str]:
return CodeLang.parse(value)[1]
@staticmethod
@@ -57,7 +58,7 @@ def supported_values() -> set[str]:
##########################################################
-def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
+def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str:
"""Convert the code language to a syntax lexer for Pygments.
Reference: https://pygments.org/docs/lexers/
diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py
index 59958991..7dd6f41a 100644
--- a/src/data_designer/config/utils/constants.py
+++ b/src/data_designer/config/utils/constants.py
@@ -11,7 +11,7 @@
DEFAULT_REPR_HTML_STYLE = "nord"
REPR_HTML_FIXED_WIDTH = 1000
-REPR_HTML_TEMPLATE = f"""
+REPR_HTML_TEMPLATE = """
{{highlighted_html}}
-"""
+""".format(fixed_width=REPR_HTML_FIXED_WIDTH)
class NordColor(Enum):
diff --git a/src/data_designer/config/utils/io_helpers.py b/src/data_designer/config/utils/io_helpers.py
index 96b7a9ff..b47d35c9 100644
--- a/src/data_designer/config/utils/io_helpers.py
+++ b/src/data_designer/config/utils/io_helpers.py
@@ -8,7 +8,7 @@
from numbers import Number
import os
from pathlib import Path
-from typing import Any
+from typing import Any, Union
import numpy as np
import pandas as pd
@@ -39,7 +39,8 @@ def read_parquet_dataset(path: Path) -> pd.DataFrame:
[pd.read_parquet(file, dtype_backend="pyarrow") for file in sorted(path.glob("*.parquet"))],
ignore_index=True,
)
- raise e
+ else:
+ raise e
def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
@@ -62,7 +63,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None:
dataframe.to_json(file_path, orient="records", lines=True)
-def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path:
+def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path:
"""Validate that a dataset file path has a valid extension and optionally exists.
Args:
@@ -82,7 +83,7 @@ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True)
return file_path
-def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
+def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
"""Load a dataframe from file if a path is given, otherwise return the dataframe.
Args:
@@ -106,14 +107,15 @@ def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
# Load the dataframe based on the file extension.
if ext == "csv":
return pd.read_csv(dataframe)
- if ext == "json":
+ elif ext == "json":
return pd.read_json(dataframe, lines=True)
- if ext == "parquet":
+ elif ext == "parquet":
return pd.read_parquet(dataframe)
- raise ValueError(f"Unsupported file format: {dataframe}")
+ else:
+ raise ValueError(f"Unsupported file format: {dataframe}")
-def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
+def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict:
"""Return the yaml config as a dict given flexible input types.
Args:
@@ -130,7 +132,8 @@ def smart_load_yaml(yaml_in: str | Path | dict) -> dict:
elif isinstance(yaml_in, str):
if yaml_in.endswith((".yaml", ".yml")) and not os.path.isfile(yaml_in):
raise FileNotFoundError(f"File not found: {yaml_in}")
- yaml_out = yaml.safe_load(yaml_in)
+ else:
+ yaml_out = yaml.safe_load(yaml_in)
else:
raise ValueError(
f"'{yaml_in}' is an invalid yaml config format. Valid options are: dict, yaml string, or yaml file path."
@@ -142,14 +145,15 @@
return yaml_out
-def serialize_data(data: dict | list | str | Number, **kwargs) -> str:
- if isinstance(data, dict) or isinstance(data, list):
+def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str:
+    if isinstance(data, (dict, list)):
return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs)
- if isinstance(data, str):
+ elif isinstance(data, str):
return data
- if isinstance(data, Number):
+ elif isinstance(data, Number):
return str(data)
- raise ValueError(f"Invalid data type: {type(data)}")
+ else:
+ raise ValueError(f"Invalid data type: {type(data)}")
def _convert_to_serializable(obj: Any) -> Any:
diff --git a/src/data_designer/config/utils/misc.py b/src/data_designer/config/utils/misc.py
index 05e5609a..c6b55d29 100644
--- a/src/data_designer/config/utils/misc.py
+++ b/src/data_designer/config/utils/misc.py
@@ -5,6 +5,7 @@
from contextlib import contextmanager
import json
+from typing import Optional, Union
from jinja2 import TemplateSyntaxError, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -57,7 +58,9 @@ def get_prompt_template_keywords(template: str) -> set[str]:
return keywords
-def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None:
+def json_indent_list_of_strings(
+ column_names: list[str], *, indent: Optional[Union[int, str]] = None
+) -> Optional[Union[list[str], str]]:
"""Convert a list of column names to a JSON string if the list is long.
This function helps keep Data Designer's __repr__ output clean and readable.
diff --git a/src/data_designer/config/utils/numerical_helpers.py b/src/data_designer/config/utils/numerical_helpers.py
index c4f8a197..7d227cd3 100644
--- a/src/data_designer/config/utils/numerical_helpers.py
+++ b/src/data_designer/config/utils/numerical_helpers.py
@@ -3,7 +3,7 @@
import numbers
from numbers import Number
-from typing import Any
+from typing import Any, Type
from .constants import REPORTING_PRECISION
@@ -18,7 +18,7 @@ def is_float(val: Any) -> bool:
def prepare_number_for_reporting(
value: Number,
- target_type: type[Number],
+ target_type: Type[Number],
precision: int = REPORTING_PRECISION,
) -> Number:
"""Ensure native python types and round to `precision` decimal digits."""
diff --git a/src/data_designer/config/utils/type_helpers.py b/src/data_designer/config/utils/type_helpers.py
index 847e1f4c..02b17bb6 100644
--- a/src/data_designer/config/utils/type_helpers.py
+++ b/src/data_designer/config/utils/type_helpers.py
@@ -3,7 +3,7 @@
from enum import Enum
import inspect
-from typing import Any
+from typing import Any, Type
from pydantic import BaseModel
@@ -11,7 +11,7 @@
from .errors import InvalidEnumValueError
-def get_sampler_params() -> dict[str, type[BaseModel]]:
+def get_sampler_params() -> dict[str, Type[BaseModel]]:
"""Returns a dictionary of sampler parameter classes."""
params_cls_list = [
params_cls
@@ -38,7 +38,7 @@ def get_sampler_params() -> dict[str, type[BaseModel]]:
return params_cls_dict
-def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
+def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum:
if not issubclass(enum_type, Enum):
raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}")
invalid_enum_value_error = InvalidEnumValueError(
@@ -47,7 +47,7 @@ def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum:
)
if isinstance(enum_instance, enum_type):
return enum_instance
- if isinstance(enum_instance, str):
+ elif isinstance(enum_instance, str):
try:
return enum_type(enum_instance)
except ValueError:
diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py
index a3328637..a3864d31 100644
--- a/src/data_designer/config/utils/validation.py
+++ b/src/data_designer/config/utils/validation.py
@@ -5,6 +5,7 @@
from enum import Enum
from string import Formatter
+from typing import Optional
from jinja2 import meta
from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -15,6 +16,7 @@
from rich.panel import Panel
from ..columns import ColumnConfigT, DataDesignerColumnType
+from ..processors import ProcessorConfig, ProcessorType
from ..validator_params import ValidatorType
from .constants import RICH_CONSOLE_THEME
from .misc import can_run_data_designer_locally
@@ -28,6 +30,7 @@ class ViolationType(str, Enum):
EXPRESSION_REFERENCE_MISSING = "expression_reference_missing"
F_STRING_SYNTAX = "f_string_syntax"
LOCAL_ONLY_COLUMN = "local_only_column"
+ INVALID_COLUMN = "invalid_column"
INVALID_MODEL_CONFIG = "invalid_model_config"
INVALID_REFERENCE = "invalid_reference"
PROMPT_WITHOUT_REFERENCES = "prompt_without_references"
@@ -39,7 +42,7 @@ class ViolationLevel(str, Enum):
class Violation(BaseModel):
- column: str | None = None
+ column: Optional[str] = None
type: ViolationType
message: str
level: ViolationLevel
@@ -51,6 +54,7 @@ def has_column(self) -> bool:
def validate_data_designer_config(
columns: list[ColumnConfigT],
+ processor_configs: list[ProcessorConfig],
allowed_references: list[str],
) -> list[Violation]:
violations = []
@@ -58,6 +62,7 @@ def validate_data_designer_config(
violations.extend(validate_code_validation(columns=columns))
violations.extend(validate_expression_references(columns=columns, allowed_references=allowed_references))
violations.extend(validate_columns_not_all_dropped(columns=columns))
+ violations.extend(validate_drop_columns_processor(columns=columns, processor_configs=processor_configs))
if not can_run_data_designer_locally():
violations.extend(validate_local_only_columns(columns=columns))
return violations
@@ -147,7 +152,7 @@ def validate_prompt_templates(
if (
prompt_type == "prompt"
and len(prompt_references) == 0
- and (not hasattr(column, "multi_modal_context") or column.multi_modal_context is None)
+ and (not hasattr(column, "multi_modal_context") or getattr(column, "multi_modal_context") is None)
):
message = (
f"The {prompt_type} template for '{column.name}' does not reference any columns. "
@@ -262,6 +267,27 @@ def validate_columns_not_all_dropped(
return []
+def validate_drop_columns_processor(
+    columns: list[ColumnConfigT],
+    processor_configs: list[ProcessorConfig],
+) -> list[Violation]:
+    all_column_names = {c.name for c in columns}
+    violations = []
+    for processor_config in processor_configs:
+        if processor_config.processor_type == ProcessorType.DROP_COLUMNS:
+            invalid_columns = set(processor_config.column_names) - all_column_names
+            violations.extend(
+                Violation(
+                    column=c,
+                    type=ViolationType.INVALID_COLUMN,
+                    message=f"Drop columns processor is configured to drop column {c!r}, but the column is not defined.",
+                    level=ViolationLevel.ERROR,
+                )
+                for c in sorted(invalid_columns)
+            )
+    return violations
+
+
def validate_expression_references(
columns: list[ColumnConfigT],
allowed_references: list[str],
diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py
index fbba948c..f245517c 100644
--- a/src/data_designer/config/utils/visualization.py
+++ b/src/data_designer/config/utils/visualization.py
@@ -7,7 +7,7 @@
from enum import Enum
from functools import cached_property
import json
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import pandas as pd
@@ -51,22 +51,23 @@ class WithRecordSamplerMixin:
def _record_sampler_dataset(self) -> pd.DataFrame:
if hasattr(self, "dataset") and self.dataset is not None and isinstance(self.dataset, pd.DataFrame):
return self.dataset
- if (
+ elif (
hasattr(self, "load_dataset")
and callable(self.load_dataset)
and (dataset := self.load_dataset()) is not None
and isinstance(dataset, pd.DataFrame)
):
return dataset
- raise DatasetSampleDisplayError("No valid dataset found in results object.")
+ else:
+ raise DatasetSampleDisplayError("No valid dataset found in results object.")
def display_sample_record(
self,
- index: int | None = None,
+ index: Optional[int] = None,
*,
hide_seed_columns: bool = False,
syntax_highlighting_theme: str = "dracula",
- background_color: str | None = None,
+ background_color: Optional[str] = None,
) -> None:
"""Display a sample record from the Data Designer dataset preview.
@@ -100,11 +101,11 @@ def display_sample_record(
def create_rich_histogram_table(
- data: dict[str, int | float],
+ data: dict[str, Union[int, float]],
-    column_names: tuple[int, int],
+    column_names: tuple[str, str],
name_style: str = ColorPalette.BLUE.value,
value_style: str = ColorPalette.TEAL.value,
- title: str | None = None,
+ title: Optional[str] = None,
**kwargs,
) -> Table:
table = Table(title=title, **kwargs)
@@ -120,11 +121,11 @@ def create_rich_histogram_table(
def display_sample_record(
- record: dict | pd.Series | pd.DataFrame,
+ record: Union[dict, pd.Series, pd.DataFrame],
config_builder: DataDesignerConfigBuilder,
- background_color: str | None = None,
+ background_color: Optional[str] = None,
syntax_highlighting_theme: str = "dracula",
- record_index: int | None = None,
+ record_index: Optional[int] = None,
hide_seed_columns: bool = False,
):
if isinstance(record, (dict, pd.Series)):
@@ -227,7 +228,7 @@ def display_sample_record(
def display_sampler_table(
sampler_params: dict[SamplerType, ConfigBase],
- title: str | None = None,
+ title: Optional[str] = None,
) -> None:
table = Table(expand=True)
table.add_column("Type")
@@ -285,7 +286,7 @@ def _get_field_type(field: dict) -> str:
return field["type"]
# union type
- if "anyOf" in field:
+ elif "anyOf" in field:
types = []
for f in field["anyOf"]:
if "$ref" in f:
diff --git a/src/data_designer/config/validator_params.py b/src/data_designer/config/validator_params.py
index c6b5e41e..0dedbda4 100644
--- a/src/data_designer/config/validator_params.py
+++ b/src/data_designer/config/validator_params.py
@@ -2,10 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
-from typing import Any, TypeAlias
+from typing import Any, Optional, Union
from pydantic import Field, field_serializer, model_validator
-from typing_extensions import Self
+from typing_extensions import Self, TypeAlias
from .base import ConfigBase
from .utils.code_lang import SQL_DIALECTS, CodeLang
@@ -35,7 +35,7 @@ class LocalCallableValidatorParams(ConfigBase):
validation_function: Any = Field(
description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
)
- output_schema: dict[str, Any] | None = Field(
+ output_schema: Optional[dict[str, Any]] = Field(
default=None, description="Expected schema for local callable validator's output"
)
@@ -52,7 +52,7 @@ def validate_validation_function(self) -> Self:
class RemoteValidatorParams(ConfigBase):
endpoint_url: str = Field(description="URL of the remote endpoint")
- output_schema: dict[str, Any] | None = Field(
+ output_schema: Optional[dict[str, Any]] = Field(
default=None, description="Expected schema for remote validator's output"
)
timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request")
@@ -61,4 +61,8 @@ class RemoteValidatorParams(ConfigBase):
max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make")
-ValidatorParamsT: TypeAlias = CodeValidatorParams | LocalCallableValidatorParams | RemoteValidatorParams
+ValidatorParamsT: TypeAlias = Union[
+ CodeValidatorParams,
+ LocalCallableValidatorParams,
+ RemoteValidatorParams,
+]
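For reviewers wondering why `|` unions keep getting rewritten: a module-level alias like this is evaluated at import time, and the `|` operator on classes only exists from Python 3.10, so the PEP 604 spelling raises `TypeError` on 3.9 even when annotations are stringified. A self-contained illustration of the portable pattern used throughout this PR:

```python
# Portable on Python 3.9: Union from typing, TypeAlias from typing_extensions.
from typing import Union

from typing_extensions import TypeAlias

IntOrStr: TypeAlias = Union[int, str]  # "int | str" here would raise TypeError on 3.9
```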
diff --git a/src/data_designer/engine/analysis/__init__.py b/src/data_designer/engine/analysis/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/analysis/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/analysis/column_profilers/__init__.py b/src/data_designer/engine/analysis/column_profilers/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/analysis/column_profilers/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py
index 911000a0..37de5dbd 100644
--- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py
+++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py
@@ -5,6 +5,7 @@
import logging
import random
+from typing import Union
from data_designer.config.analysis.column_profilers import (
JudgeScoreProfilerConfig,
@@ -95,7 +96,7 @@ def _summarize_score_sample(
name: str,
sample: list[JudgeScoreSample],
histogram: CategoricalHistogramData,
- distribution: CategoricalDistribution | NumericalDistribution | MissingValue,
+ distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue],
distribution_type: ColumnDistributionType,
) -> JudgeScoreSummary:
if isinstance(distribution, MissingValue) or not sample:
@@ -107,7 +108,7 @@ def _summarize_score_sample(
category_info = []
total_count = sum(histogram.counts)
- for cat, count in zip(histogram.categories, histogram.counts, strict=False):
+ for cat, count in zip(histogram.categories, histogram.counts):
percentage = (count / total_count) * 100
category_info.append(f"{cat}: {count} records ({percentage:.1f}%)")
diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py
index c01a3737..dd4fb1e9 100644
--- a/src/data_designer/engine/analysis/column_statistics.py
+++ b/src/data_designer/engine/analysis/column_statistics.py
@@ -4,7 +4,7 @@
from __future__ import annotations
import logging
-from typing import Any, TypeAlias
+from typing import Any, Type, Union
import pandas as pd
from pydantic import BaseModel
+from typing_extensions import TypeAlias
@@ -49,7 +49,7 @@ def df(self) -> pd.DataFrame:
return self.column_config_with_df.df
@property
- def column_statistics_type(self) -> type[ColumnStatisticsT]:
+ def column_statistics_type(self) -> Type[ColumnStatisticsT]:
return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics)
def calculate(self) -> Self:
@@ -146,17 +146,17 @@ class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
}
-ColumnStatisticsCalculatorT: TypeAlias = (
- ExpressionColumnStatisticsCalculator
- | ValidationColumnStatisticsCalculator
- | GeneralColumnStatisticsCalculator
- | LLMCodeColumnStatisticsCalculator
- | LLMJudgedColumnStatisticsCalculator
- | LLMStructuredColumnStatisticsCalculator
- | LLMTextColumnStatisticsCalculator
- | SamplerColumnStatisticsCalculator
- | SeedDatasetColumnStatisticsCalculator
-)
+ColumnStatisticsCalculatorT: TypeAlias = Union[
+ ExpressionColumnStatisticsCalculator,
+ ValidationColumnStatisticsCalculator,
+ GeneralColumnStatisticsCalculator,
+ LLMCodeColumnStatisticsCalculator,
+ LLMJudgedColumnStatisticsCalculator,
+ LLMStructuredColumnStatisticsCalculator,
+ LLMTextColumnStatisticsCalculator,
+ SamplerColumnStatisticsCalculator,
+ SeedDatasetColumnStatisticsCalculator,
+]
DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = {
DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator,
DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator,
diff --git a/src/data_designer/engine/analysis/reporting/__init__.py b/src/data_designer/engine/analysis/reporting/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/analysis/reporting/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/analysis/utils/__init__.py b/src/data_designer/engine/analysis/utils/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/analysis/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/analysis/utils/judge_score_processing.py b/src/data_designer/engine/analysis/utils/judge_score_processing.py
index 8190e9f6..95b01ccc 100644
--- a/src/data_designer/engine/analysis/utils/judge_score_processing.py
+++ b/src/data_designer/engine/analysis/utils/judge_score_processing.py
@@ -3,7 +3,7 @@
from collections import defaultdict
import logging
-from typing import Any
+from typing import Any, Optional, Union
import pandas as pd
@@ -21,7 +21,7 @@
def extract_judge_score_distributions(
column_config: LLMJudgeColumnConfig, df: pd.DataFrame
-) -> JudgeScoreDistributions | MissingValue:
+) -> Union[JudgeScoreDistributions, MissingValue]:
scores = defaultdict(list)
reasoning = defaultdict(list)
@@ -79,10 +79,10 @@ def extract_judge_score_distributions(
def sample_scores_and_reasoning(
- scores: list[int | str],
+ scores: list[Union[int, str]],
reasoning: list[str],
num_samples: int,
- random_seed: int | None = None,
+ random_seed: Optional[int] = None,
) -> list[JudgeScoreSample]:
if len(scores) != len(reasoning):
raise ValueError("scores and reasoning must have the same length")
@@ -96,10 +96,7 @@ def sample_scores_and_reasoning(
df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning})
if len(scores) <= num_samples:
- return [
- JudgeScoreSample(score=score, reasoning=reasoning)
- for score, reasoning in zip(scores, reasoning, strict=False)
- ]
+ return [JudgeScoreSample(score=score, reasoning=reasoning) for score, reasoning in zip(scores, reasoning)]
# Sample maintaining original proportions from each category (int or str)
# Calculate the frequency of each score category
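A quick sketch of the early-return path touched above, with illustrative values: when there are no more score/reasoning pairs than `num_samples`, every pair is returned in order.

```python
# Illustrative values only.
samples = sample_scores_and_reasoning(
    scores=[5, 4, 5],
    reasoning=["clear and complete", "minor issues", "excellent"],
    num_samples=10,  # more than the three available pairs
    random_seed=42,
)
assert [s.score for s in samples] == [5, 4, 5]
```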
diff --git a/src/data_designer/engine/column_generators/generators/expression.py b/src/data_designer/engine/column_generators/generators/expression.py
index 2f465918..7da80e66 100644
--- a/src/data_designer/engine/column_generators/generators/expression.py
+++ b/src/data_designer/engine/column_generators/generators/expression.py
@@ -50,11 +50,11 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame:
def _cast_type(self, value: str) -> str | float | int | bool:
if self.config.dtype == "str":
return value
- if self.config.dtype == "float":
+ elif self.config.dtype == "float":
return float(value)
- if self.config.dtype == "int":
+ elif self.config.dtype == "int":
return int(float(value))
- if self.config.dtype == "bool":
+ elif self.config.dtype == "bool":
try:
return bool(int(float(value)))
except ValueError:
diff --git a/src/data_designer/engine/column_generators/generators/samplers.py b/src/data_designer/engine/column_generators/generators/samplers.py
index 3f274292..d28e0a61 100644
--- a/src/data_designer/engine/column_generators/generators/samplers.py
+++ b/src/data_designer/engine/column_generators/generators/samplers.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from collections.abc import Callable
from functools import partial
import logging
import random
+from typing import Callable
import pandas as pd
diff --git a/src/data_designer/engine/column_generators/generators/validation.py b/src/data_designer/engine/column_generators/generators/validation.py
index 504a2ab6..424165fc 100644
--- a/src/data_designer/engine/column_generators/generators/validation.py
+++ b/src/data_designer/engine/column_generators/generators/validation.py
@@ -35,7 +35,7 @@ def get_validator_from_params(validator_type: ValidatorType, validator_params: V
if validator_type == ValidatorType.CODE:
if validator_params.code_lang == CodeLang.PYTHON:
return PythonValidator(validator_params)
- if validator_params.code_lang in SQL_DIALECTS:
+ elif validator_params.code_lang in SQL_DIALECTS:
return SQLValidator(validator_params)
elif validator_type == ValidatorType.REMOTE:
return RemoteValidator(validator_params)
diff --git a/src/data_designer/engine/column_generators/utils/__init__.py b/src/data_designer/engine/column_generators/utils/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/column_generators/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/src/data_designer/engine/column_generators/utils/judge_score_factory.py
index 40212dbd..b4d458c6 100644
--- a/src/data_designer/engine/column_generators/utils/judge_score_factory.py
+++ b/src/data_designer/engine/column_generators/utils/judge_score_factory.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
+from typing import Type
from pydantic import BaseModel, ConfigDict, Field, create_model
@@ -18,7 +19,7 @@ class BaseJudgeResponse(BaseModel):
reasoning: str = Field(..., description="Reasoning for the assigned score.")
-def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
+def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str:
"""Convert score descriptions into a single text block."""
list_block = "\n".join(
[SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()]
@@ -26,7 +27,7 @@ def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block)
-def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
+def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]:
"""Create a JudgeResponse data type."""
enum_members = {}
for option in score.options.keys():
@@ -45,8 +46,8 @@ def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]:
def create_judge_structured_output_model(
- judge_responses: list[type[BaseJudgeResponse]],
-) -> type[BaseModel]:
+ judge_responses: list[Type[BaseJudgeResponse]],
+) -> Type[BaseModel]:
"""Create a JudgeStructuredOutput class dynamically."""
return create_model(
"JudgeStructuredOutput",
diff --git a/src/data_designer/engine/configurable_task.py b/src/data_designer/engine/configurable_task.py
index 94d8f799..0c3f10a5 100644
--- a/src/data_designer/engine/configurable_task.py
+++ b/src/data_designer/engine/configurable_task.py
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Generic, TypeVar
+from typing import Generic, Type, TypeVar
import pandas as pd
@@ -30,7 +30,7 @@ def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider |
self._initialize()
@classmethod
- def get_config_type(cls) -> type[TaskConfigT]:
+ def get_config_type(cls) -> Type[TaskConfigT]:
for base in cls.__orig_bases__:
if hasattr(base, "__args__") and len(base.__args__) == 1 and issubclass(base.__args__[0], ConfigBase):
return base.__args__[0]
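`get_config_type` leans on `__orig_bases__` introspection. A self-contained analogue of the mechanism, independent of this codebase, for anyone reviewing the `Type[...]` change:

```python
from typing import Generic, TypeVar

ConfigT = TypeVar("ConfigT")

class Task(Generic[ConfigT]):
    @classmethod
    def config_type(cls) -> type:
        # __orig_bases__ records the parameterized base, e.g. Task[MyConfig]
        for base in cls.__orig_bases__:
            if hasattr(base, "__args__") and len(base.__args__) == 1:
                return base.__args__[0]
        raise TypeError("no parameterized base found")

class MyConfig: ...
class MyTask(Task[MyConfig]): ...

assert MyTask.config_type() is MyConfig
```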
diff --git a/src/data_designer/engine/dataset_builders/__init__.py b/src/data_designer/engine/dataset_builders/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/dataset_builders/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/dataset_builders/artifact_storage.py b/src/data_designer/engine/dataset_builders/artifact_storage.py
index c29f4591..4a1e4919 100644
--- a/src/data_designer/engine/dataset_builders/artifact_storage.py
+++ b/src/data_designer/engine/dataset_builders/artifact_storage.py
@@ -6,6 +6,7 @@
import logging
from pathlib import Path
import shutil
+from typing import Union
import pandas as pd
from pydantic import BaseModel, field_validator, model_validator
@@ -57,7 +58,7 @@ def partial_results_path(self) -> Path:
return self.base_dataset_path / self.partial_results_folder_name
@field_validator("artifact_path")
- def validate_artifact_path(cls, v: Path | str) -> Path:
+ def validate_artifact_path(cls, v: Union[Path, str]) -> Path:
v = Path(v)
if not v.is_dir():
raise ArtifactStorageError("Artifact path must exist and be a directory")
diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py
index 430a56ea..f16a5a96 100644
--- a/src/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -1,20 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from collections.abc import Callable
import functools
import json
import logging
from pathlib import Path
import time
+from typing import Callable
import pandas as pd
from data_designer.config.columns import ColumnConfigT
+from data_designer.config.dataset_builders import BuildStage
+from data_designer.config.processors import (
+ DropColumnsProcessorConfig,
+ ProcessorConfig,
+ ProcessorType,
+)
from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
-from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage, BatchStage
-from data_designer.engine.dataset_builders.errors import DatasetGenerationError
+from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
+from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import (
DatasetBuilderColumnConfigT,
MultiColumnConfig,
@@ -26,7 +32,7 @@
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import (
DatasetBatchManager,
)
-from data_designer.engine.processing.processors.configs import DropColumnsProcessorConfig
+from data_designer.engine.processing.processors.base import Processor
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
from data_designer.engine.resources.resource_provider import ResourceProvider
@@ -38,6 +44,7 @@ class ColumnWiseDatasetBuilder:
def __init__(
self,
column_configs: list[DatasetBuilderColumnConfigT],
+ processor_configs: list[ProcessorConfig],
resource_provider: ResourceProvider,
registry: DataDesignerRegistry | None = None,
):
@@ -46,6 +53,7 @@ def __init__(
self._records_to_drop: set[int] = set()
self._registry = registry or DataDesignerRegistry()
self._column_configs = column_configs
+ self._processors: dict[BuildStage, list[Processor]] = self._initialize_processors(processor_configs)
self._validate_column_configs()
@property
@@ -79,8 +87,12 @@ def build(
for batch_idx in range(1, self.batch_manager.num_batches + 1):
logger.info(f"⏳ Processing batch {batch_idx} of {self.batch_manager.num_batches}")
self._run_batch(generators)
- df_batch = self.batch_manager.get_current_batch(as_dataframe=True)
- self._write_processed_batch(self.drop_columns_if_needed(df_batch, save_dropped_columns=True))
+ df_batch = self._run_processors(
+ stage=BuildStage.POST_BATCH,
+ dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
+ current_batch_number=batch_idx,
+ )
+ self._write_processed_batch(df_batch)
self.batch_manager.finish_batch(on_batch_complete)
self.batch_manager.finish()
@@ -99,7 +111,11 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
start_time = time.perf_counter()
self.batch_manager.start(num_records=num_records, buffer_size=num_records)
self._run_batch(generators, save_partial_results=False)
- dataset = self.batch_manager.get_current_batch(as_dataframe=True)
+ dataset = self._run_processors(
+ stage=BuildStage.POST_BATCH,
+ dataframe=self.batch_manager.get_current_batch(as_dataframe=True),
+ current_batch_number=None, # preview mode does not have a batch number
+ )
self.batch_manager.reset()
model_usage_stats = self._resource_provider.model_registry.get_model_usage_stats(
@@ -109,29 +125,6 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
return dataset
- def drop_columns_if_needed(self, dataframe: pd.DataFrame, *, save_dropped_columns: bool = False) -> pd.DataFrame:
- if len(columns_to_drop := [config.name for config in self.single_column_configs if config.drop]) == 0:
- return dataframe
- try:
- dropped_column_parquet_file_name = (
- None
- if not save_dropped_columns
- else self.artifact_storage.create_batch_file_path(
- batch_number=self.batch_manager.get_current_batch_number(),
- batch_stage=BatchStage.DROPPED_COLUMNS,
- ).name
- )
- df = DropColumnsProcessor(
- config=DropColumnsProcessorConfig(
- column_names=columns_to_drop,
- dropped_column_parquet_file_name=dropped_column_parquet_file_name,
- ),
- resource_provider=self._resource_provider,
- ).process(dataframe)
- return df
- except Exception as e:
- raise DatasetGenerationError(f"🛑 Failed to drop columns {columns_to_drop}: {e}")
-
def _initialize_generators(self) -> list[ColumnGenerator]:
return [
self._registry.column_generators.get_for_config_type(type(config))(
@@ -218,6 +211,51 @@ def _validate_column_configs(self) -> None:
).can_generate_from_scratch:
raise DatasetGenerationError("🛑 The first column config must be a from-scratch column generator.")
+ def _initialize_processors(self, processor_configs: list[ProcessorConfig]) -> dict[BuildStage, list[Processor]]:
+ # Check columns marked for drop
+ columns_to_drop = [config.name for config in self.single_column_configs if config.drop]
+
+ processors: dict[BuildStage, list[Processor]] = {stage: [] for stage in BuildStage}
+ for config in processor_configs:
+ processors[config.build_stage].append(
+ self._registry.processors.get_for_config_type(type(config))(
+ config=config,
+ resource_provider=self._resource_provider,
+ )
+ )
+
+            # An explicitly configured "drop columns" processor takes precedence for the
+            # columns it names (it can, e.g., run in a stage other than post-batch)
+            if config.processor_type == ProcessorType.DROP_COLUMNS:
+                for column in config.column_names:
+                    if column in columns_to_drop:
+                        columns_to_drop.remove(column)
+
+ # If there are still columns marked for drop, add the "drop columns" processor to drop them
+ if len(columns_to_drop) > 0:
+ processors[BuildStage.POST_BATCH].append( # as post-batch by default
+ DropColumnsProcessor(
+ config=DropColumnsProcessorConfig(
+ column_names=columns_to_drop,
+ build_stage=BuildStage.POST_BATCH,
+ ),
+ resource_provider=self._resource_provider,
+ )
+ )
+
+ return processors
+
+ def _run_processors(
+ self, stage: BuildStage, dataframe: pd.DataFrame, current_batch_number: int | None = None
+ ) -> pd.DataFrame:
+ for processor in self._processors[stage]:
+ try:
+ dataframe = processor.process(dataframe, current_batch_number=current_batch_number)
+ except Exception as e:
+ raise DatasetProcessingError(
+ f"🛑 Failed to process dataset with processor {processor.metadata().name} in stage {stage}: {e}"
+ ) from e
+ return dataframe
+
def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
"""If a worker fails, we can handle the exception here."""
logger.warning(
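Putting the builder changes together, the wiring now looks roughly like this. A hedged sketch: construction of `column_configs` and `resource_provider` is elided, the column name is hypothetical, and the argument and enum names come from this diff.

```python
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.processors import DropColumnsProcessorConfig

builder = ColumnWiseDatasetBuilder(
    column_configs=column_configs,
    processor_configs=[
        DropColumnsProcessorConfig(
            column_names=["scratch_col"],       # hypothetical column to drop
            build_stage=BuildStage.POST_BATCH,  # run after each batch is generated
        )
    ],
    resource_provider=resource_provider,
)
df_preview = builder.build_preview(num_records=10)  # processors see current_batch_number=None
```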
diff --git a/src/data_designer/engine/dataset_builders/errors.py b/src/data_designer/engine/dataset_builders/errors.py
index d99a387a..ee4b6138 100644
--- a/src/data_designer/engine/dataset_builders/errors.py
+++ b/src/data_designer/engine/dataset_builders/errors.py
@@ -8,3 +8,6 @@ class ArtifactStorageError(DataDesignerError): ...
class DatasetGenerationError(DataDesignerError): ...
+
+
+class DatasetProcessingError(DataDesignerError): ...
diff --git a/src/data_designer/engine/dataset_builders/utils/concurrency.py b/src/data_designer/engine/dataset_builders/utils/concurrency.py
index 3b305e06..b5760b6a 100644
--- a/src/data_designer/engine/dataset_builders/utils/concurrency.py
+++ b/src/data_designer/engine/dataset_builders/utils/concurrency.py
@@ -8,7 +8,7 @@
import json
import logging
from threading import Lock, Semaphore
-from typing import Any, Protocol
+from typing import Any, Optional, Protocol
from pydantic import BaseModel, Field
@@ -46,13 +46,13 @@ def is_error_rate_exceeded(self, window: int) -> bool:
class CallbackWithContext(Protocol):
"""Executor callback functions must accept a context kw argument."""
- def __call__(self, result: Any, *, context: dict | None = None) -> Any: ...
+ def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ...
class ErrorCallbackWithContext(Protocol):
"""Error callbacks take the Exception instance and context."""
- def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ...
+ def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ...
class ConcurrentThreadExecutor:
@@ -92,8 +92,8 @@ def __init__(
*,
max_workers: int,
column_name: str,
- result_callback: CallbackWithContext | None = None,
- error_callback: ErrorCallbackWithContext | None = None,
+ result_callback: Optional[CallbackWithContext] = None,
+ error_callback: Optional[ErrorCallbackWithContext] = None,
shutdown_error_rate: float = 0.50,
shutdown_error_window: int = 10,
):
@@ -136,7 +136,7 @@ def _raise_task_error(self):
)
)
- def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None:
+ def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None:
if self._executor is None:
raise RuntimeError("Executor is not initialized, this class should be used as a context manager.")
diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py
index e302a99e..2a784212 100644
--- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py
+++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -3,6 +3,7 @@
from data_designer.config.columns import DataDesignerColumnType
from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.processors import ProcessorConfig
from data_designer.engine.dataset_builders.multi_column_configs import (
DatasetBuilderColumnConfigT,
SamplerMultiColumnConfig,
@@ -50,3 +51,9 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D
compiled_column_configs.extend(generated_column_configs)
return compiled_column_configs
+
+
+def compile_dataset_builder_processor_configs(
+ config: DataDesignerConfig,
+) -> list[ProcessorConfig]:
+ return config.processors or []
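The compiler pass-through keeps the builder decoupled from `DataDesignerConfig`; a config with no processors compiles to an empty list:

```python
# Assuming `config` is a DataDesignerConfig whose `processors` field is unset.
assert compile_dataset_builder_processor_configs(config) == []
```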
diff --git a/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py
index 58826da5..240a3a64 100644
--- a/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py
+++ b/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from collections.abc import Callable, Container, Iterator
import logging
from pathlib import Path
import shutil
+from typing import Callable, Container, Iterator
import pandas as pd
import pyarrow.parquet as pq
diff --git a/src/data_designer/engine/model_provider.py b/src/data_designer/engine/model_provider.py
index deceef7c..e25f2e50 100644
--- a/src/data_designer/engine/model_provider.py
+++ b/src/data_designer/engine/model_provider.py
@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from functools import cached_property
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Self
from data_designer.engine.errors import NoModelProvidersError, UnknownProviderError
@@ -53,7 +53,7 @@ def check_default_exists(self) -> Self:
raise ValueError(f"Specified default {self.default!r} not found in providers list")
return self
- def _get_default_provider_name(self) -> str:
+ def get_default_provider_name(self) -> str:
return self.default or self.providers[0].name
@cached_property
@@ -62,7 +62,7 @@ def _providers_dict(self) -> dict[str, ModelProvider]:
def get_provider(self, name: str | None) -> ModelProvider:
if name is None:
- name = self._get_default_provider_name()
+ name = self.get_default_provider_name()
try:
return self._providers_dict[name]
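With the underscore dropped, callers can resolve the effective default directly. A small sketch; `provider_config` stands in for an instance of the class modified above (its name falls outside this hunk):

```python
name = provider_config.get_default_provider_name()  # self.default, else providers[0].name
provider = provider_config.get_provider(None)       # None resolves to the same default
```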
diff --git a/src/data_designer/engine/models/errors.py b/src/data_designer/engine/models/errors.py
index 242ddfbe..d725fa83 100644
--- a/src/data_designer/engine/models/errors.py
+++ b/src/data_designer/engine/models/errors.py
@@ -44,7 +44,8 @@ def get_exception_primary_cause(exception: BaseException) -> BaseException:
"""
if exception.__cause__ is None:
return exception
- return get_exception_primary_cause(exception.__cause__)
+ else:
+ return get_exception_primary_cause(exception.__cause__)
class GenerationValidationFailureError(Exception): ...
diff --git a/src/data_designer/engine/models/litellm_overrides.py b/src/data_designer/engine/models/litellm_overrides.py
index 78f9acae..41208a0d 100644
--- a/src/data_designer/engine/models/litellm_overrides.py
+++ b/src/data_designer/engine/models/litellm_overrides.py
@@ -5,6 +5,7 @@
import random
import threading
+from typing import Optional, Union
import httpx
import litellm
@@ -89,7 +90,7 @@ def __init__(
self._initial_retry_after_s = initial_retry_after_s
self._jitter_pct = jitter_pct
- def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
+ def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]:
"""
Most of this code logic was extracted directly from the parent
`Router`'s `_time_to_sleep_before_retry` function. Our override
@@ -98,7 +99,7 @@ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
return this info, we'll simply use that retry value returned here.
"""
- response_headers: httpx.Headers | None = None
+ response_headers: Optional[httpx.Headers] = None
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
response_headers = e.response.headers # type: ignore
if hasattr(e, "litellm_response_headers"):
@@ -109,7 +110,8 @@ def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None:
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
if retry_after is not None and 0 < retry_after <= 60:
return retry_after
- return None
+ else:
+ return None
@override
def _time_to_sleep_before_retry(
@@ -117,9 +119,9 @@ def _time_to_sleep_before_retry(
e: Exception,
remaining_retries: int,
num_retries: int,
- healthy_deployments: list | None = None,
- all_deployments: list | None = None,
- ) -> int | float:
+ healthy_deployments: Optional[list] = None,
+ all_deployments: Optional[list] = None,
+ ) -> Union[int, float]:
"""
Implements exponential backoff for retries.
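For context on the docstring above, the fallback delay has roughly this shape. This is an approximation for reviewers, not the exact override, which first honors a server-provided retry delay as shown in `_extract_retry_delay_from_headers`:

```python
import random

def fallback_backoff_s(attempt: int, initial_retry_after_s: float = 1.0, jitter_pct: float = 0.25) -> float:
    # Exponential growth per attempt, spread by +/- jitter to avoid thundering herds.
    delay = initial_retry_after_s * (2**attempt)
    return delay * (1 + random.uniform(-jitter_pct, jitter_pct))
```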
diff --git a/src/data_designer/engine/models/parsers/errors.py b/src/data_designer/engine/models/parsers/errors.py
index 087e8a83..cf0fd411 100644
--- a/src/data_designer/engine/models/parsers/errors.py
+++ b/src/data_designer/engine/models/parsers/errors.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
class ParserException(Exception):
"""Identifies errors resulting from generic parser errors.
@@ -10,7 +12,7 @@ class ParserException(Exception):
attempted to parse.
"""
- source: str | None
+ source: Optional[str]
@staticmethod
def _log_format(source: str) -> str:
@@ -22,7 +24,7 @@ def _log_format(source: str) -> str:
# return f"{source}"
return ""
- def __init__(self, msg: str | None = None, source: str | None = None):
+ def __init__(self, msg: Optional[str] = None, source: Optional[str] = None):
msg = "" if msg is None else msg.strip()
if source is not None:
diff --git a/src/data_designer/engine/models/parsers/parser.py b/src/data_designer/engine/models/parsers/parser.py
index 3037978a..62acdfc2 100644
--- a/src/data_designer/engine/models/parsers/parser.py
+++ b/src/data_designer/engine/models/parsers/parser.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from functools import reduce
+from typing import Optional
from lxml import etree
from lxml.etree import _Element
@@ -104,8 +105,8 @@ def __call__(self, element: _Element) -> CodeBlock:
def __init__(
self,
- tag_parsers: dict[str, TagParser] | None = None,
- postprocessors: list[PostProcessor] | None = None,
+ tag_parsers: Optional[dict[str, TagParser]] = None,
+ postprocessors: Optional[list[PostProcessor]] = None,
):
"""
Initializes the LLMResponseParser with optional tag parsers and post-processors.
diff --git a/src/data_designer/engine/models/parsers/postprocessors.py b/src/data_designer/engine/models/parsers/postprocessors.py
index 1cce5290..d7959505 100644
--- a/src/data_designer/engine/models/parsers/postprocessors.py
+++ b/src/data_designer/engine/models/parsers/postprocessors.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Type
import json_repair
from pydantic import BaseModel, ValidationError
@@ -59,12 +60,12 @@ def deserialize_json_code(
class RealizePydanticTypes:
- types: list[type[BaseModel]]
+ types: list[Type[BaseModel]]
- def __init__(self, types: list[type[BaseModel]]):
+ def __init__(self, types: list[Type[BaseModel]]):
self.types = types
- def _fit_types(self, obj: dict) -> BaseModel | None:
+ def _fit_types(self, obj: dict) -> Optional[BaseModel]:
final_obj = None
for t in self.types:
diff --git a/src/data_designer/engine/models/parsers/types.py b/src/data_designer/engine/models/parsers/types.py
index 1fa66647..17be38a2 100644
--- a/src/data_designer/engine/models/parsers/types.py
+++ b/src/data_designer/engine/models/parsers/types.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-from typing import Any, Protocol, runtime_checkable
+from typing import Any, Optional, Protocol, Type, runtime_checkable
from lxml.etree import _Element
from pydantic import BaseModel, Field
@@ -30,7 +30,7 @@ def tail(self, n: int) -> Self:
out.parsed = out.parsed[-n:]
return out
- def filter(self, block_types: list[type[BaseModel]]) -> Self:
+ def filter(self, block_types: list[Type[BaseModel]]) -> Self:
out = self.model_copy()
out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
return out
@@ -69,7 +69,7 @@ class TextBlock(BaseModel):
class CodeBlock(BaseModel):
code: str
- code_lang: str | None = None
+ code_lang: Optional[str] = None
class StructuredDataBlock(BaseModel):
diff --git a/src/data_designer/engine/models/recipes/__init__.py b/src/data_designer/engine/models/recipes/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/models/recipes/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/processing/__init__.py b/src/data_designer/engine/processing/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/processing/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/processing/ginja/ast.py b/src/data_designer/engine/processing/ginja/ast.py
index 9171365f..2d1fecb3 100644
--- a/src/data_designer/engine/processing/ginja/ast.py
+++ b/src/data_designer/engine/processing/ginja/ast.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from collections import deque
+from typing import Optional, Type
from jinja2 import nodes as j_nodes
@@ -32,7 +33,7 @@ def ast_max_depth(node: j_nodes.Node) -> int:
return max_depth
-def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int:
+def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int:
"""Count the number of nodes which descend from the given node.
Args:
diff --git a/src/data_designer/engine/processing/ginja/exceptions.py b/src/data_designer/engine/processing/ginja/exceptions.py
index 9e796172..780ce358 100644
--- a/src/data_designer/engine/processing/ginja/exceptions.py
+++ b/src/data_designer/engine/processing/ginja/exceptions.py
@@ -43,9 +43,12 @@ def maybe_handle_missing_filter_exception(exception: BaseException, available_ji
match = re.search(r"No filter named '([^']+)'", exc_message)
if not match:
return
- missing_filter_name = match.group(1)
- available_filter_str = ", ".join(available_jinja_filters)
- raise UserTemplateUnsupportedFiltersError(
- f"The Jinja2 filter `{{{{ ... | {missing_filter_name} }}}}` "
- f"is not a permitted operation. Available filters: {available_filter_str}"
- ) from exception
+    else:
+        missing_filter_name = match.group(1)
+        available_filter_str = ", ".join(available_jinja_filters)
+        raise UserTemplateUnsupportedFiltersError(
+            f"The Jinja2 filter `{{{{ ... | {missing_filter_name} }}}}` "
+            f"is not a permitted operation. Available filters: {available_filter_str}"
+        ) from exception
diff --git a/src/data_designer/engine/processing/processors/__init__.py b/src/data_designer/engine/processing/processors/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/processing/processors/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/processing/processors/base.py b/src/data_designer/engine/processing/processors/base.py
index 98b342f8..48c06576 100644
--- a/src/data_designer/engine/processing/processors/base.py
+++ b/src/data_designer/engine/processing/processors/base.py
@@ -12,4 +12,4 @@ class Processor(ConfigurableTask[TaskConfigT], ABC):
def metadata() -> ConfigurableTaskMetadata: ...
@abstractmethod
- def process(self, data: DataT) -> DataT: ...
+ def process(self, data: DataT, *, current_batch_number: int | None = None) -> DataT: ...
diff --git a/src/data_designer/engine/processing/processors/configs.py b/src/data_designer/engine/processing/processors/configs.py
deleted file mode 100644
index 25651836..00000000
--- a/src/data_designer/engine/processing/processors/configs.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from pydantic import field_validator
-
-from data_designer.config.base import ConfigBase
-
-
-class DropColumnsProcessorConfig(ConfigBase):
- column_names: list[str]
- dropped_column_parquet_file_name: str | None = None
-
- @field_validator("dropped_column_parquet_file_name")
- def validate_dropped_column_parquet_file_name(cls, v: str | None) -> str | None:
- if v is not None and not v.endswith(".parquet"):
- raise ValueError("Dropped column parquet file name must end with .parquet")
- return v
diff --git a/src/data_designer/engine/processing/processors/drop_columns.py b/src/data_designer/engine/processing/processors/drop_columns.py
index 88dd2943..bc996a9e 100644
--- a/src/data_designer/engine/processing/processors/drop_columns.py
+++ b/src/data_designer/engine/processing/processors/drop_columns.py
@@ -5,10 +5,10 @@
import pandas as pd
+from data_designer.config.processors import DropColumnsProcessorConfig
from data_designer.engine.configurable_task import ConfigurableTaskMetadata
from data_designer.engine.dataset_builders.artifact_storage import BatchStage
from data_designer.engine.processing.processors.base import Processor
-from data_designer.engine.processing.processors.configs import DropColumnsProcessorConfig
logger = logging.getLogger(__name__)
@@ -22,9 +22,10 @@ def metadata() -> ConfigurableTaskMetadata:
required_resources=None,
)
- def process(self, data: pd.DataFrame) -> pd.DataFrame:
+ def process(self, data: pd.DataFrame, *, current_batch_number: int | None = None) -> pd.DataFrame:
logger.info(f"🙈 Dropping columns: {self.config.column_names}")
- self._save_dropped_columns_if_needed(data)
+ if current_batch_number is not None: # not in preview mode
+ self._save_dropped_columns_if_needed(data, current_batch_number)
for column in self.config.column_names:
if column in data.columns:
data.drop(columns=[column], inplace=True)
@@ -32,11 +33,14 @@ def process(self, data: pd.DataFrame) -> pd.DataFrame:
logger.warning(f"⚠️ Cannot drop column: `{column}` not found in the dataset.")
return data
- def _save_dropped_columns_if_needed(self, data: pd.DataFrame) -> None:
- if self.config.dropped_column_parquet_file_name:
- logger.debug("📦 Saving dropped columns to dropped-columns directory")
- self.artifact_storage.write_parquet_file(
- parquet_file_name=self.config.dropped_column_parquet_file_name,
- dataframe=data[self.config.column_names],
- batch_stage=BatchStage.DROPPED_COLUMNS,
- )
+ def _save_dropped_columns_if_needed(self, data: pd.DataFrame, current_batch_number: int) -> None:
+ logger.debug("📦 Saving dropped columns to dropped-columns directory")
+ dropped_column_parquet_file_name = self.artifact_storage.create_batch_file_path(
+ batch_number=current_batch_number,
+ batch_stage=BatchStage.DROPPED_COLUMNS,
+ ).name
+ self.artifact_storage.write_parquet_file(
+ parquet_file_name=dropped_column_parquet_file_name,
+ dataframe=data[self.config.column_names],
+ batch_stage=BatchStage.DROPPED_COLUMNS,
+ )
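The new contract in one sketch (`processor` construction is elided; note that `process` mutates its input via `drop(..., inplace=True)`, hence the copies):

```python
processed = processor.process(df_batch.copy(), current_batch_number=1)   # persists dropped columns to parquet
preview = processor.process(df_batch.copy(), current_batch_number=None)  # preview: drops only, persists nothing
```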
diff --git a/src/data_designer/engine/processing/processors/registry.py b/src/data_designer/engine/processing/processors/registry.py
new file mode 100644
index 00000000..dadcbc33
--- /dev/null
+++ b/src/data_designer/engine/processing/processors/registry.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.processors import (
+ DropColumnsProcessorConfig,
+ ProcessorType,
+)
+from data_designer.engine.processing.processors.base import Processor
+from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
+from data_designer.engine.registry.base import TaskRegistry
+
+
+class ProcessorRegistry(TaskRegistry[str, Processor, ConfigBase]): ...
+
+
+def create_default_processor_registry() -> ProcessorRegistry:
+ registry = ProcessorRegistry()
+ registry.register(ProcessorType.DROP_COLUMNS, DropColumnsProcessor, DropColumnsProcessorConfig, False)
+ return registry
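Registering a custom processor follows the same pattern as the built-in entry above; a hedged sketch with hypothetical `MyProcessor`/`MyProcessorConfig` stand-ins:

```python
registry = create_default_processor_registry()
registry.register("my_processor", MyProcessor, MyProcessorConfig)
assert registry.get_for_config_type(MyProcessorConfig) is MyProcessor
```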
diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py
index 2da38a69..3579b3bd 100644
--- a/src/data_designer/engine/processing/utils.py
+++ b/src/data_designer/engine/processing/utils.py
@@ -3,7 +3,7 @@
import json
import logging
-from typing import Any, TypeVar, overload
+from typing import Any, TypeVar, Union, overload
import pandas as pd
@@ -25,7 +25,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame:
# Overloads to help static type checker better understand
# the input/output types of the deserialize_json_values function.
@overload
-def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ...
+def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ...
@overload
diff --git a/src/data_designer/engine/registry/__init__.py b/src/data_designer/engine/registry/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/registry/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/registry/base.py b/src/data_designer/engine/registry/base.py
index 5d435004..5f780940 100644
--- a/src/data_designer/engine/registry/base.py
+++ b/src/data_designer/engine/registry/base.py
@@ -3,7 +3,7 @@
from enum import StrEnum
import threading
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Type, TypeVar
from data_designer.config.base import ConfigBase
from data_designer.engine.configurable_task import ConfigurableTask
@@ -16,14 +16,14 @@
class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
# registered type name -> type
- _registry: dict[EnumNameT, type[TaskT]] = {}
+ _registry: dict[EnumNameT, Type[TaskT]] = {}
# type -> registered type name
- _reverse_registry: dict[type[TaskT], EnumNameT] = {}
+ _reverse_registry: dict[Type[TaskT], EnumNameT] = {}
# registered type name -> config type
- _config_registry: dict[EnumNameT, type[TaskConfigT]] = {}
+ _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {}
# config type -> registered type name
- _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {}
+ _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {}
# all registries are singletons
_instance = None
@@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]):
def register(
cls,
name: EnumNameT,
- task: type[TaskT],
- config: type[TaskConfigT],
+ task: Type[TaskT],
+ config: Type[TaskConfigT],
raise_on_collision: bool = True,
) -> None:
if cls._has_been_registered(name):
@@ -52,22 +52,22 @@ def register(
cls._reverse_config_registry[config] = name
@classmethod
- def get_task_type(cls, name: EnumNameT) -> type[TaskT]:
+ def get_task_type(cls, name: EnumNameT) -> Type[TaskT]:
cls._raise_if_not_registered(name, cls._registry)
return cls._registry[name]
@classmethod
- def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]:
+ def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]:
cls._raise_if_not_registered(name, cls._config_registry)
return cls._config_registry[name]
@classmethod
- def get_registered_name(cls, task: type[TaskT]) -> EnumNameT:
+ def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT:
cls._raise_if_not_registered(task, cls._reverse_registry)
return cls._reverse_registry[task]
@classmethod
- def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]:
+ def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]:
cls._raise_if_not_registered(config, cls._reverse_config_registry)
name = cls._reverse_config_registry[config]
return cls.get_task_type(name)
@@ -77,7 +77,7 @@ def _has_been_registered(cls, name: EnumNameT) -> bool:
return name in cls._registry
@classmethod
- def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None:
+ def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None:
if not (isinstance(key, StrEnum) or isinstance(key, str)):
cls._raise_if_not_type(key)
if key not in mapping:
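The four mirrored dicts are what make config-driven dispatch cheap: round-tripping with the drop-columns processor registered by `create_default_processor_registry` (all names below appear in this PR):

```python
from data_designer.config.processors import DropColumnsProcessorConfig, ProcessorType
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
from data_designer.engine.processing.processors.registry import create_default_processor_registry

registry = create_default_processor_registry()
assert registry.get_for_config_type(DropColumnsProcessorConfig) is DropColumnsProcessor
assert registry.get_registered_name(DropColumnsProcessor) == ProcessorType.DROP_COLUMNS
```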
diff --git a/src/data_designer/engine/registry/data_designer_registry.py b/src/data_designer/engine/registry/data_designer_registry.py
index c498995d..8ed2f0ba 100644
--- a/src/data_designer/engine/registry/data_designer_registry.py
+++ b/src/data_designer/engine/registry/data_designer_registry.py
@@ -9,6 +9,10 @@
ColumnGeneratorRegistry,
create_default_column_generator_registry,
)
+from data_designer.engine.processing.processors.registry import (
+ ProcessorRegistry,
+ create_default_processor_registry,
+)
class DataDesignerRegistry:
@@ -17,9 +21,11 @@ def __init__(
*,
column_generator_registry: ColumnGeneratorRegistry | None = None,
column_profiler_registry: ColumnProfilerRegistry | None = None,
+ processor_registry: ProcessorRegistry | None = None,
):
self._column_generator_registry = column_generator_registry or create_default_column_generator_registry()
self._column_profiler_registry = column_profiler_registry or create_default_column_profiler_registry()
+ self._processor_registry = processor_registry or create_default_processor_registry()
@property
def column_generators(self) -> ColumnGeneratorRegistry:
@@ -28,3 +34,7 @@ def column_generators(self) -> ColumnGeneratorRegistry:
@property
def column_profilers(self) -> ColumnProfilerRegistry:
return self._column_profiler_registry
+
+ @property
+ def processors(self) -> ProcessorRegistry:
+ return self._processor_registry
diff --git a/src/data_designer/engine/resources/__init__.py b/src/data_designer/engine/resources/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/resources/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/resources/managed_storage.py b/src/data_designer/engine/resources/managed_storage.py
index 946477a4..cd7cb0b8 100644
--- a/src/data_designer/engine/resources/managed_storage.py
+++ b/src/data_designer/engine/resources/managed_storage.py
@@ -118,7 +118,7 @@ def init_managed_blob_storage(assets_storage: str = "s3://gretel-managed-assets-
return S3BlobStorageProvider(bucket_name=bucket_name)
- if assets_storage.startswith("/"):
+ elif assets_storage.startswith("/"):
path = Path(assets_storage)
if not path.exists():
raise RuntimeError(f"Local storage path {assets_storage!r} does not exist.")
diff --git a/src/data_designer/engine/resources/seed_dataset_data_store.py b/src/data_designer/engine/resources/seed_dataset_data_store.py
index 4b206333..3295f41e 100644
--- a/src/data_designer/engine/resources/seed_dataset_data_store.py
+++ b/src/data_designer/engine/resources/seed_dataset_data_store.py
@@ -2,15 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
-import os
-import tempfile
-from datasets import DatasetDict, load_dataset
import duckdb
from huggingface_hub import HfApi, HfFileSystem
-import pandas as pd
-from data_designer.config.utils.io_helpers import validate_dataset_file_path
from data_designer.logging import quiet_noisy_logger
quiet_noisy_logger("httpx")
@@ -31,9 +26,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
@abstractmethod
def get_dataset_uri(self, file_id: str) -> str: ...
- @abstractmethod
- def load_dataset(self, file_id: str) -> pd.DataFrame: ...
-
class LocalSeedDatasetDataStore(SeedDatasetDataStore):
"""Local filesystem-based dataset storage."""
@@ -44,20 +36,6 @@ def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
def get_dataset_uri(self, file_id: str) -> str:
return file_id
- def load_dataset(self, file_id: str) -> pd.DataFrame:
- filepath = validate_dataset_file_path(file_id)
- match filepath.suffix.lower():
- case ".csv":
- return pd.read_csv(filepath)
- case ".parquet":
- return pd.read_parquet(filepath)
- case ".json":
- return pd.read_json(filepath, lines=True)
- case ".jsonl":
- return pd.read_json(filepath, lines=True)
- case _:
- raise ValueError("Local datasets must be CSV, Parquet, JSON, or JSONL")
-
class HfHubSeedDatasetDataStore(SeedDatasetDataStore):
"""Hugging Face and Data Store dataset storage."""
@@ -76,54 +54,6 @@ def get_dataset_uri(self, file_id: str) -> str:
repo_id, filename = self._get_repo_id_and_filename(identifier)
return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}"
- def load_dataset(self, file_id: str) -> pd.DataFrame:
- identifier = file_id.removeprefix(_HF_DATASETS_PREFIX)
- repo_id, filename = self._get_repo_id_and_filename(identifier)
- is_file = "." in file_id.split("/")[-1]
-
- self._validate_repo(repo_id)
-
- if is_file:
- self._validate_file(repo_id, filename)
- return self._download_and_load_file(repo_id, filename)
- return self._download_and_load_directory(repo_id, filename)
-
- def _validate_repo(self, repo_id: str) -> None:
- """Validate that the repository exists and is a dataset repo."""
- if not self.hfapi.repo_exists(repo_id, repo_type="dataset"):
- if self.hfapi.repo_exists(repo_id, repo_type="model"):
- raise FileNotFoundError(f"Repo {repo_id} is a model repo, not a dataset repo")
- raise FileNotFoundError(f"Repo {repo_id} does not exist")
-
- def _validate_file(self, repo_id: str, filename: str) -> None:
- """Validate that the file exists in the repository."""
- if not self.hfapi.file_exists(repo_id, filename, repo_type="dataset"):
- raise FileNotFoundError(f"File {filename} does not exist in repo {repo_id}")
-
- def _download_and_load_file(self, repo_id: str, filename: str) -> pd.DataFrame:
- """Download a specific file and load it as a dataset."""
- with tempfile.TemporaryDirectory() as temp_dir:
- self.hfapi.hf_hub_download(
- repo_id=repo_id,
- filename=filename,
- local_dir=temp_dir,
- repo_type="dataset",
- )
- return self._load_local_dataset(temp_dir)
-
- def _download_and_load_directory(self, repo_id: str, directory: str) -> pd.DataFrame:
- """Download entire repo and load from specific subdirectory."""
- with tempfile.TemporaryDirectory() as temp_dir:
- self.hfapi.snapshot_download(
- repo_id=repo_id,
- local_dir=temp_dir,
- repo_type="dataset",
- )
- dataset_path = os.path.join(temp_dir, directory)
- if not os.path.exists(dataset_path):
- dataset_path = temp_dir
- return self._load_local_dataset(dataset_path)
-
def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
"""Extract repo_id and filename from identifier."""
parts = identifier.split("/", 2)
@@ -134,10 +64,3 @@ def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
)
repo_ns, repo_name, filename = parts
return f"{repo_ns}/{repo_name}", filename
-
- def _load_local_dataset(self, path: str) -> pd.DataFrame:
- """Load dataset from local path."""
- hf_dataset = load_dataset(path=path)
- if isinstance(hf_dataset, DatasetDict):
- hf_dataset = hf_dataset[list(hf_dataset.keys())[0]]
- return hf_dataset.to_pandas()
diff --git a/src/data_designer/engine/sampling_gen/__init__.py b/src/data_designer/engine/sampling_gen/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/sampling_gen/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/sampling_gen/constraints.py b/src/data_designer/engine/sampling_gen/constraints.py
index 4d719de4..30ee688f 100644
--- a/src/data_designer/engine/sampling_gen/constraints.py
+++ b/src/data_designer/engine/sampling_gen/constraints.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
+from typing import Type
import numpy as np
from numpy.typing import NDArray
@@ -90,5 +91,5 @@ def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]:
}
-def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]:
+def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]:
return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)]
diff --git a/src/data_designer/engine/sampling_gen/data_sources/__init__.py b/src/data_designer/engine/sampling_gen/data_sources/__init__.py
deleted file mode 100644
index 4ee5de4a..00000000
--- a/src/data_designer/engine/sampling_gen/data_sources/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
diff --git a/src/data_designer/engine/sampling_gen/data_sources/base.py b/src/data_designer/engine/sampling_gen/data_sources/base.py
index 7d8961fd..acfd2e2a 100644
--- a/src/data_designer/engine/sampling_gen/data_sources/base.py
+++ b/src/data_designer/engine/sampling_gen/data_sources/base.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Optional, Type, TypeVar, Union
import numpy as np
from numpy.typing import NDArray
@@ -45,7 +45,7 @@ def postproc(series: pd.Series, convert_to: str) -> pd.Series:
return series
@staticmethod
- def validate_data_conversion(convert_to: str | None) -> None:
+ def validate_data_conversion(convert_to: Optional[str]) -> None:
pass
@@ -71,7 +71,7 @@ def preproc(series: pd.Series, convert_to: str) -> pd.Series:
return series
@staticmethod
- def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
+ def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
if convert_to is not None:
if convert_to == "int":
series = series.round()
@@ -79,18 +79,18 @@ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
return series
@staticmethod
- def validate_data_conversion(convert_to: str | None) -> None:
+ def validate_data_conversion(convert_to: Optional[str]) -> None:
if convert_to is not None and convert_to not in ["float", "int", "str"]:
raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]")
class DatetimeFormatMixin:
@staticmethod
- def preproc(series: pd.Series, convert_to: str | None) -> pd.Series:
+ def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
return series
@staticmethod
- def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
+ def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series:
if convert_to is not None:
return series.dt.strftime(convert_to)
if series.dt.month.nunique() == 1:
@@ -104,7 +104,7 @@ def postproc(series: pd.Series, convert_to: str | None) -> pd.Series:
return series.apply(lambda dt: dt.isoformat()).astype(str)
@staticmethod
- def validate_data_conversion(convert_to: str | None) -> None:
+ def validate_data_conversion(convert_to: Optional[str]) -> None:
if convert_to is not None:
try:
pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to))
@@ -121,7 +121,7 @@ class DataSource(ABC, Generic[GenericParamsT]):
def __init__(
self,
params: GenericParamsT,
- random_state: RadomStateT | None = None,
+ random_state: Optional[RadomStateT] = None,
**kwargs,
):
self.rng = check_random_state(random_state)
@@ -130,7 +130,7 @@ def __init__(
self._validate()
@classmethod
- def get_param_type(cls) -> type[GenericParamsT]:
+ def get_param_type(cls) -> Type[GenericParamsT]:
return cls.__orig_bases__[-1].__args__[0]
@abstractmethod
@@ -138,7 +138,7 @@ def inject_data_column(
self,
dataframe: pd.DataFrame,
column_name: str,
- index: list[int] | None = None,
+ index: Optional[list[int]] = None,
) -> pd.DataFrame: ...
@staticmethod
@@ -147,11 +147,11 @@ def preproc(series: pd.Series) -> pd.Series: ...
@staticmethod
@abstractmethod
- def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ...
+ def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: ...
@staticmethod
@abstractmethod
- def validate_data_conversion(convert_to: str | None) -> None: ...
+ def validate_data_conversion(convert_to: Optional[str]) -> None: ...
def get_required_column_names(self) -> tuple[str, ...]:
return tuple()
@@ -182,7 +182,7 @@ def inject_data_column(
self,
dataframe: pd.DataFrame,
column_name: str,
- index: list[int] | None = None,
+ index: Optional[list[int]] = None,
) -> pd.DataFrame:
index = slice(None) if index is None else index
@@ -208,7 +208,7 @@ def sample(self, num_samples: int) -> NumpyArray1dT: ...
class ScipyStatsSampler(Sampler[GenericParamsT], ABC):
@property
@abstractmethod
- def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ...
+ def distribution(self) -> Union[stats.rv_continuous, stats.rv_discrete]: ...
def sample(self, num_samples: int) -> NumpyArray1dT:
return self.distribution.rvs(size=num_samples, random_state=self.rng)
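One construct worth calling out in this file is `get_param_type`, which recovers the concrete params type from the parameterized generic base. A self-contained sketch of that `__orig_bases__` trick under illustrative names (runtime behavior only; static checkers will not follow it):

```python
# Recovering a Generic subclass's type parameter at runtime.
from typing import Generic, Type, TypeVar

ParamsT = TypeVar("ParamsT")


class Params:
    pass


class Source(Generic[ParamsT]):
    @classmethod
    def get_param_type(cls) -> Type[ParamsT]:
        # __orig_bases__ preserves the parameterized base (e.g. Source[Params]),
        # so the last base's __args__ holds the concrete params type.
        return cls.__orig_bases__[-1].__args__[0]


class MySource(Source[Params]):
    pass


assert MySource.get_param_type() is Params
```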
diff --git a/src/data_designer/engine/sampling_gen/entities/email_address_utils.py b/src/data_designer/engine/sampling_gen/entities/email_address_utils.py
index 17db6d2f..7dce2d3c 100644
--- a/src/data_designer/engine/sampling_gen/entities/email_address_utils.py
+++ b/src/data_designer/engine/sampling_gen/entities/email_address_utils.py
@@ -77,17 +77,18 @@ def get_email_domain_by_age(age: int) -> str:
weights=list(email_domains_under_30.values()),
k=1,
)[0]
- if age < 50:
+ elif age < 50:
return random.choices(
list(email_domains_30_50.keys()),
weights=list(email_domains_30_50.values()),
k=1,
)[0]
- return random.choices(
- list(email_domains_over_50.keys()),
- weights=list(email_domains_over_50.values()),
- k=1,
- )[0]
+ else:
+ return random.choices(
+ list(email_domains_over_50.keys()),
+ weights=list(email_domains_over_50.values()),
+ k=1,
+ )[0]
def get_email_basename_by_name(first_name: str, middle_name: str | None, last_name: str) -> str:
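The `if`/`elif`/`else` rewrite above is behavior-preserving (every branch returns) but makes the three age brackets read as mutually exclusive. A condensed sketch of the bracketed weighted-choice pattern, with made-up domain weights:

```python
# Illustrative age-bracketed weighted choice; the weights are invented.
import random

DOMAINS_UNDER_30 = {"gmail.com": 0.7, "icloud.com": 0.3}
DOMAINS_30_50 = {"gmail.com": 0.6, "outlook.com": 0.4}
DOMAINS_OVER_50 = {"aol.com": 0.5, "gmail.com": 0.5}


def get_email_domain_by_age(age: int) -> str:
    if age < 30:
        pool = DOMAINS_UNDER_30
    elif age < 50:
        pool = DOMAINS_30_50
    else:
        pool = DOMAINS_OVER_50
    # random.choices draws k samples with replacement, weighted by `weights`.
    return random.choices(list(pool.keys()), weights=list(pool.values()), k=1)[0]
```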
diff --git a/src/data_designer/engine/sampling_gen/entities/person.py b/src/data_designer/engine/sampling_gen/entities/person.py
index 1b2b1f15..685916f8 100644
--- a/src/data_designer/engine/sampling_gen/entities/person.py
+++ b/src/data_designer/engine/sampling_gen/entities/person.py
@@ -78,10 +78,10 @@ def generate_phone_number(locale: str, age: int, postcode: str | None, style: st
if locality_var < 0.6:
# Exact match to postcode 60% of the time
return PhoneNumber.from_zip_prefix(postcode).format(style=style)
- if locality_var < 0.8:
+ elif locality_var < 0.8:
# Nearby postcodes 20% of the time
return PhoneNumber.from_zip_prefix(postcode[:4]).format(style=style)
- if locality_var < 0.9:
+ elif locality_var < 0.9:
# More distant postcodes 10% of the time
return PhoneNumber.from_zip_prefix(postcode[:3]).format(style=style)
# Random (population-weighted) area code 10% of the time
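Same `elif` cleanup here, and the chain is easier to audit once read as a single probability cascade: one uniform draw of `locality_var`, with thresholds 0.6/0.8/0.9 splitting the mass 60/20/10/10. A sketch of just the cascade, with placeholder branch bodies:

```python
# Tiered-probability cascade over one uniform draw; tier names are placeholders.
import random


def pick_locality_tier(rng: random.Random) -> str:
    locality_var = rng.random()  # uniform in [0, 1)
    if locality_var < 0.6:
        return "exact-postcode"   # 60% of the time
    elif locality_var < 0.8:
        return "nearby-postcode"  # 20% of the time
    elif locality_var < 0.9:
        return "distant-postcode" # 10% of the time
    return "random-area-code"     # remaining 10%
```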
diff --git a/src/data_designer/engine/sampling_gen/entities/phone_number.py b/src/data_designer/engine/sampling_gen/entities/phone_number.py
index 422c3644..20a7394c 100644
--- a/src/data_designer/engine/sampling_gen/entities/phone_number.py
+++ b/src/data_designer/engine/sampling_gen/entities/phone_number.py
@@ -3,16 +3,17 @@
from pathlib import Path
import random
+from typing import Optional
import pandas as pd
from pydantic import BaseModel, Field, field_validator
ZIP_AREA_CODE_DATA = pd.read_parquet(Path(__file__).parent / "assets" / "zip_area_code_map.parquet")
-ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"], strict=False))
-ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"], strict=False))
+ZIPCODE_AREA_CODE_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["area_code"]))
+ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"]))
-def get_area_code(zip_prefix: str | None = None) -> str:
+def get_area_code(zip_prefix: Optional[str] = None) -> str:
"""
Sample an area code for the given ZIP code prefix, population-weighted.
@@ -23,7 +24,7 @@ def get_area_code(zip_prefix: str | None = None) -> str:
A sampled area code matching the prefix, population-weighted.
"""
if zip_prefix is None:
- zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items(), strict=False)
+ zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items())
zipcode = random.choices(zipcodes, weights=weights, k=1)[0]
return str(ZIPCODE_AREA_CODE_MAP[zipcode])
if len(zip_prefix) == 5:
@@ -32,7 +33,7 @@ def get_area_code(zip_prefix: str | None = None) -> str:
except KeyError:
raise ValueError(f"ZIP code {zip_prefix} not found.")
matching_zipcodes = [[z, c] for z, c in ZIPCODE_POPULATION_MAP.items() if z.startswith(zip_prefix)]
- zipcodes, weights = zip(*matching_zipcodes, strict=False)
+ zipcodes, weights = zip(*matching_zipcodes)
if not zipcodes:
raise ValueError(f"No ZIP codes found with prefix {zip_prefix}.")
zipcode = random.choices(zipcodes, weights=weights, k=1)[0]
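Dropping `strict=False` here is a pure compatibility fix: `zip()` only accepts the `strict` keyword on Python 3.10+, and `strict=False` is the default behavior anyway, so nothing changes semantically. A small sketch of the population-weighted draw with illustrative map contents:

```python
import random

# Population-weighted zipcode draw, mirroring the hunk above; data is invented.
ZIPCODE_POPULATION_MAP = {"10001": 27000, "10002": 76000}

# zip()'s strict keyword is 3.10+ only; on 3.9 and earlier passing it raises
# TypeError. strict=False was the default, so removing it is behavior-identical.
zipcodes, weights = zip(*ZIPCODE_POPULATION_MAP.items())
zipcode = random.choices(zipcodes, weights=weights, k=1)[0]
```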
diff --git a/src/data_designer/engine/secret_resolver.py b/src/data_designer/engine/secret_resolver.py
index 0da339eb..521064d4 100644
--- a/src/data_designer/engine/secret_resolver.py
+++ b/src/data_designer/engine/secret_resolver.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
import json
import logging
import os
@@ -30,7 +31,7 @@ def resolve(self, secret: str) -> str:
try:
return self._secrets[secret]
except KeyError:
- raise SecretResolutionError(f"No secret found with key {secret!r}")
+ raise SecretResolutionError(f"No secret found in secrets file with key {secret!r}")
class EnvironmentResolver(SecretResolver):
@@ -47,18 +48,22 @@ def resolve(self, secret: str) -> str:
class CompositeResolver(SecretResolver):
- _resolvers: list[SecretResolver]
+ _resolvers: Sequence[SecretResolver]
- def __init__(self, resolvers: list[SecretResolver]):
+ def __init__(self, resolvers: Sequence[SecretResolver]):
if len(resolvers) == 0:
raise SecretResolutionError("Must provide at least one SecretResolver to CompositeResolver")
self._resolvers = resolvers
def resolve(self, secret: str) -> str:
+ errors = []
for resolver in self._resolvers:
try:
return resolver.resolve(secret)
- except SecretResolutionError:
+ except SecretResolutionError as err:
+ errors.append(str(err))
continue
- raise SecretResolutionError(f"No configured resolvers were able to resolve secret {secret!r}")
+ raise SecretResolutionError(
+ f"No configured resolvers were able to resolve secret {secret!r}: {', '.join(errors)}"
+ )
diff --git a/src/data_designer/engine/validators/base.py b/src/data_designer/engine/validators/base.py
index d18ab678..18f9597e 100644
--- a/src/data_designer/engine/validators/base.py
+++ b/src/data_designer/engine/validators/base.py
@@ -2,14 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
-from collections.abc import Iterator
+from typing import Iterator, Optional
from pydantic import BaseModel, ConfigDict
from typing_extensions import Self
class ValidationOutput(BaseModel):
- is_valid: bool | None
+ is_valid: Optional[bool]
model_config = ConfigDict(extra="allow")
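Note that `Self` stays in typing_extensions, since `typing.Self` only exists on 3.11+ and would undo the compatibility work this PR is doing. The model itself is a three-state validity flag (presumably `None` means validation was skipped or inconclusive) whose `extra="allow"` config lets validators attach arbitrary diagnostics. A small sketch of that behavior:

```python
# Three-state validation result; extra="allow" keeps unknown fields (pydantic v2).
from typing import Optional

from pydantic import BaseModel, ConfigDict


class ValidationOutput(BaseModel):
    is_valid: Optional[bool]  # True, False, or None for skipped/inconclusive
    model_config = ConfigDict(extra="allow")


out = ValidationOutput(is_valid=None, error="timeout while validating")
print(out.model_dump())  # {'is_valid': None, 'error': 'timeout while validating'}
```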
diff --git a/src/data_designer/engine/validators/python.py b/src/data_designer/engine/validators/python.py
index f71b678b..5b42265e 100644
--- a/src/data_designer/engine/validators/python.py
+++ b/src/data_designer/engine/validators/python.py
@@ -238,7 +238,7 @@ def _get_scores(stats_by_module: dict[str, dict[str, int]]) -> dict[str, float]:
def _count_python_statements(file_path: str) -> int:
"""Count the number of statements in a Python file."""
try:
- with open(file_path, encoding="utf-8") as f:
+ with open(file_path, "r", encoding="utf-8") as f:
tree = ast.parse(f.read())
return sum(1 for node in ast.walk(tree) if isinstance(node, ast.stmt))
except Exception:
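The surrounding function's statement count works because `ast.stmt` is the common base class of every statement node. A runnable miniature of the same counting logic:

```python
# Counting statements with ast, as _count_python_statements does above.
import ast

source = "x = 1\nfor i in range(3):\n    x += i\n"
tree = ast.parse(source)
# ast.stmt covers Assign, For, AugAssign, and every other statement node.
num_statements = sum(1 for node in ast.walk(tree) if isinstance(node, ast.stmt))
print(num_statements)  # 3: the assignment, the for loop, the augmented assignment
```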
diff --git a/src/data_designer/essentials/__init__.py b/src/data_designer/essentials/__init__.py
index 50e21493..3be1950f 100644
--- a/src/data_designer/essentials/__init__.py
+++ b/src/data_designer/essentials/__init__.py
@@ -16,6 +16,7 @@
)
from ..config.config_builder import DataDesignerConfigBuilder
from ..config.data_designer_config import DataDesignerConfig
+from ..config.dataset_builders import BuildStage
from ..config.datastore import DatastoreSettings
from ..config.models import (
ImageContext,
@@ -30,6 +31,7 @@
UniformDistribution,
UniformDistributionParams,
)
+from ..config.processors import DropColumnsProcessorConfig, ProcessorType
from ..config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
from ..config.sampler_params import (
BernoulliMixtureSamplerParams,
@@ -80,9 +82,11 @@
"DataDesignerColumnType",
"DataDesignerConfig",
"DataDesignerConfigBuilder",
+ "BuildStage",
"DatastoreSeedDatasetReference",
"DatastoreSettings",
"DatetimeSamplerParams",
+ "DropColumnsProcessorConfig",
"ExpressionColumnConfig",
"GaussianSamplerParams",
"ImageContext",
@@ -102,6 +106,7 @@
"ModelConfig",
"PersonSamplerParams",
"PoissonSamplerParams",
+ "ProcessorType",
"RemoteValidatorParams",
"SamplerColumnConfig",
"SamplerType",
diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py
index b0d79aef..13d16cb5 100644
--- a/src/data_designer/interface/data_designer.py
+++ b/src/data_designer/interface/data_designer.py
@@ -204,8 +204,6 @@ def preview(
except Exception as e:
raise DataDesignerProfilingError(f"🛑 Error profiling preview dataset: {e}")
- dataset = builder.drop_columns_if_needed(dataset)
-
if len(dataset) > 0 and isinstance(analysis, DatasetProfilerResults) and len(analysis.column_statistics) > 0:
logger.info(f"{RandomEmoji.success()} Preview complete!")
@@ -237,6 +235,7 @@ def _create_dataset_builder(
) -> ColumnWiseDatasetBuilder:
return ColumnWiseDatasetBuilder(
column_configs=compile_dataset_builder_column_configs(config_builder.build(raise_exceptions=True)),
+ processor_configs=config_builder.get_processor_configs(),
resource_provider=resource_provider,
)
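The two hunks in this file are one refactor: `preview` no longer drops columns on the finished dataset because the processor configs are now injected into `ColumnWiseDatasetBuilder`, which presumably applies them during the build. A stand-in sketch of that inversion (types and field names are illustrative, not the real builder):

```python
# Threading processor configs into a builder instead of post-processing results.
from dataclasses import dataclass, field
from typing import List

import pandas as pd


@dataclass
class DropColumnsConfig:
    column_names: List[str]


@dataclass
class Builder:
    processor_configs: List[DropColumnsConfig] = field(default_factory=list)

    def build(self, df: pd.DataFrame) -> pd.DataFrame:
        # Processors run inside the builder, so callers no longer need a
        # separate drop_columns_if_needed() pass on the returned dataset.
        for cfg in self.processor_configs:
            df = df.drop(columns=cfg.column_names, errors="ignore")
        return df


df = pd.DataFrame({"keep": [1], "scratch": [2]})
print(Builder([DropColumnsConfig(["scratch"])]).build(df).columns.tolist())  # ['keep']
```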
diff --git a/src/data_designer/logging.py b/src/data_designer/logging.py
index e50bf752..77a8210d 100644
--- a/src/data_designer/logging.py
+++ b/src/data_designer/logging.py
@@ -6,7 +6,7 @@
from pathlib import Path
import random
import sys
-from typing import TextIO
+from typing import TextIO, Union
from pythonjsonlogger import jsonlogger
@@ -19,7 +19,7 @@ class LoggerConfig:
@dataclass
class OutputConfig:
- destination: TextIO | Path
+ destination: Union[TextIO, Path]
structured: bool
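A `Union[TextIO, Path]` destination implies a type dispatch somewhere downstream when the handler is attached. A minimal sketch of what that dispatch might look like (the helper is illustrative, not the module's actual wiring):

```python
# Dispatching on Union[TextIO, Path] when building a log handler.
import logging
import sys
from pathlib import Path
from typing import TextIO, Union


def make_handler(destination: Union[TextIO, Path]) -> logging.Handler:
    if isinstance(destination, Path):
        return logging.FileHandler(destination)  # write to a file on disk
    return logging.StreamHandler(destination)    # write to an open stream


print(make_handler(sys.stderr))       # StreamHandler wrapping stderr
print(make_handler(Path("run.log")))  # FileHandler writing to run.log
```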
diff --git a/tests/config/analysis/utils/test_reporting.py b/tests/config/analysis/utils/test_reporting.py
index 10625a6a..ec9a9912 100644
--- a/tests/config/analysis/utils/test_reporting.py
+++ b/tests/config/analysis/utils/test_reporting.py
@@ -60,7 +60,7 @@ def test_generate_analysis_report_with_save_path_html(sample_dataset_profiler_re
assert Path(tmp_path).stat().st_size > 0
# Verify it's valid HTML by checking for basic HTML structure
- with open(tmp_path) as f:
+ with open(tmp_path, "r") as f:
content = f.read()
assert " 0
# Verify it's valid SVG by checking for SVG structure
- with open(tmp_path) as f:
+ with open(tmp_path, "r") as f:
content = f.read()
assert "