diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 5f07595d..db382f88 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -29,7 +29,7 @@ from data_designer.config.default_model_settings import get_default_model_configs from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError from data_designer.config.models import ModelConfig, load_model_configs -from data_designer.config.processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs +from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs from data_designer.config.sampler_constraints import ( ColumnConstraintT, ColumnInequalityConstraint, @@ -141,7 +141,7 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] """ self._column_configs = {} self._model_configs = _load_model_configs(model_configs) - self._processor_configs: list[ProcessorConfig] = [] + self._processor_configs: list[ProcessorConfigT] = [] self._seed_config: Optional[SeedConfig] = None self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] @@ -298,7 +298,7 @@ def add_constraint( def add_processor( self, - processor_config: Optional[ProcessorConfig] = None, + processor_config: Optional[ProcessorConfigT] = None, *, processor_type: Optional[ProcessorType] = None, **kwargs, @@ -487,7 +487,7 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis column_type = resolve_string_enum(column_type, DataDesignerColumnType) return [c for c in self._column_configs.values() if c.column_type != column_type] - def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfig]]: + def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]: """Get processor configuration objects. Returns: diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index 06f02995..d90deb41 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -11,7 +11,7 @@ from data_designer.config.base import ExportableConfigBase from data_designer.config.column_types import ColumnConfigT from data_designer.config.models import ModelConfig -from data_designer.config.processors import ProcessorConfig +from data_designer.config.processors import ProcessorConfigT from data_designer.config.sampler_constraints import ColumnConstraintT from data_designer.config.seed import SeedConfig @@ -37,4 +37,4 @@ class DataDesignerConfig(ExportableConfigBase): seed_config: Optional[SeedConfig] = None constraints: Optional[list[ColumnConstraintT]] = None profilers: Optional[list[ColumnProfilerConfigT]] = None - processors: Optional[list[ProcessorConfig]] = None + processors: Optional[list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]]] = None diff --git a/src/data_designer/config/processors.py b/src/data_designer/config/processors.py index 171e10e3..17d2ff7b 100644 --- a/src/data_designer/config/processors.py +++ b/src/data_designer/config/processors.py @@ -4,9 +4,10 @@ import json from abc import ABC from enum import Enum -from typing import Any, Literal +from typing import Any, Literal, Union from pydantic import Field, field_validator +from typing_extensions import TypeAlias from data_designer.config.base import ConfigBase from data_designer.config.dataset_builders import BuildStage @@ -47,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC): default=BuildStage.POST_BATCH, description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}", ) + processor_type: str @field_validator("build_stage") def validate_build_stage(cls, v: BuildStage) -> BuildStage: @@ -139,3 +141,9 @@ def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]: if "not JSON serializable" in str(e): raise InvalidConfigError("Template must be JSON serializable") return v + + +ProcessorConfigT: TypeAlias = Union[ + DropColumnsProcessorConfig, + SchemaTransformProcessorConfig, +] diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index 7d3654ad..dc1ca2e3 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -16,7 +16,7 @@ from rich.panel import Panel from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated -from data_designer.config.processors import ProcessorConfig, ProcessorType +from data_designer.config.processors import ProcessorConfigT, ProcessorType from data_designer.config.utils.constants import RICH_CONSOLE_THEME from data_designer.config.utils.misc import ( can_run_data_designer_locally, @@ -57,7 +57,7 @@ def has_column(self) -> bool: def validate_data_designer_config( columns: list[ColumnConfigT], - processor_configs: list[ProcessorConfig], + processor_configs: list[ProcessorConfigT], allowed_references: list[str], ) -> list[Violation]: violations = [] @@ -273,7 +273,7 @@ def validate_columns_not_all_dropped( def validate_drop_columns_processor( columns: list[ColumnConfigT], - processor_configs: list[ProcessorConfig], + processor_configs: list[ProcessorConfigT], ) -> list[Violation]: all_column_names = {c.name for c in columns} for processor_config in processor_configs: @@ -294,7 +294,7 @@ def validate_drop_columns_processor( def validate_schema_transform_processor( columns: list[ColumnConfigT], - processor_configs: list[ProcessorConfig], + processor_configs: list[ProcessorConfigT], ) -> list[Violation]: violations = []