Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/data_designer/config/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from data_designer.config.default_model_settings import get_default_model_configs
from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
from data_designer.config.models import ModelConfig, load_model_configs
from data_designer.config.processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs
from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
from data_designer.config.sampler_constraints import (
ColumnConstraintT,
ColumnInequalityConstraint,
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]]
"""
self._column_configs = {}
self._model_configs = _load_model_configs(model_configs)
self._processor_configs: list[ProcessorConfig] = []
self._processor_configs: list[ProcessorConfigT] = []
self._seed_config: Optional[SeedConfig] = None
self._constraints: list[ColumnConstraintT] = []
self._profilers: list[ColumnProfilerConfigT] = []
Expand Down Expand Up @@ -298,7 +298,7 @@ def add_constraint(

def add_processor(
self,
processor_config: Optional[ProcessorConfig] = None,
processor_config: Optional[ProcessorConfigT] = None,
*,
processor_type: Optional[ProcessorType] = None,
**kwargs,
Expand Down Expand Up @@ -487,7 +487,7 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
return [c for c in self._column_configs.values() if c.column_type != column_type]

def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfig]]:
def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]:
"""Get processor configuration objects.

Returns:
Expand Down
4 changes: 2 additions & 2 deletions src/data_designer/config/data_designer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from data_designer.config.base import ExportableConfigBase
from data_designer.config.column_types import ColumnConfigT
from data_designer.config.models import ModelConfig
from data_designer.config.processors import ProcessorConfig
from data_designer.config.processors import ProcessorConfigT
from data_designer.config.sampler_constraints import ColumnConstraintT
from data_designer.config.seed import SeedConfig

Expand All @@ -37,4 +37,4 @@ class DataDesignerConfig(ExportableConfigBase):
seed_config: Optional[SeedConfig] = None
constraints: Optional[list[ColumnConstraintT]] = None
profilers: Optional[list[ColumnProfilerConfigT]] = None
processors: Optional[list[ProcessorConfig]] = None
processors: Optional[list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]]] = None
10 changes: 9 additions & 1 deletion src/data_designer/config/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import json
from abc import ABC
from enum import Enum
from typing import Any, Literal
from typing import Any, Literal, Union

from pydantic import Field, field_validator
from typing_extensions import TypeAlias

from data_designer.config.base import ConfigBase
from data_designer.config.dataset_builders import BuildStage
Expand Down Expand Up @@ -47,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC):
default=BuildStage.POST_BATCH,
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
)
processor_type: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we changing this to a string?

We actually made the same change for column types, which makes plugins easier to work with.


@field_validator("build_stage")
def validate_build_stage(cls, v: BuildStage) -> BuildStage:
Expand Down Expand Up @@ -139,3 +141,9 @@ def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
if "not JSON serializable" in str(e):
raise InvalidConfigError("Template must be JSON serializable")
return v


# Union of the concrete processor config types. Annotations that accept
# "any processor config" use this alias, and DataDesignerConfig.processors
# wraps it with Field(discriminator="processor_type").
# NOTE(review): presumably each member declares a distinct `processor_type`
# value so the discriminator can select the right class — confirm.
ProcessorConfigT: TypeAlias = Union[
DropColumnsProcessorConfig,
SchemaTransformProcessorConfig,
]
8 changes: 4 additions & 4 deletions src/data_designer/config/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from rich.panel import Panel

from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
from data_designer.config.processors import ProcessorConfig, ProcessorType
from data_designer.config.processors import ProcessorConfigT, ProcessorType
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
from data_designer.config.utils.misc import (
can_run_data_designer_locally,
Expand Down Expand Up @@ -57,7 +57,7 @@ def has_column(self) -> bool:

def validate_data_designer_config(
columns: list[ColumnConfigT],
processor_configs: list[ProcessorConfig],
processor_configs: list[ProcessorConfigT],
allowed_references: list[str],
) -> list[Violation]:
violations = []
Expand Down Expand Up @@ -273,7 +273,7 @@ def validate_columns_not_all_dropped(

def validate_drop_columns_processor(
columns: list[ColumnConfigT],
processor_configs: list[ProcessorConfig],
processor_configs: list[ProcessorConfigT],
) -> list[Violation]:
all_column_names = {c.name for c in columns}
for processor_config in processor_configs:
Expand All @@ -294,7 +294,7 @@ def validate_drop_columns_processor(

def validate_schema_transform_processor(
columns: list[ColumnConfigT],
processor_configs: list[ProcessorConfig],
processor_configs: list[ProcessorConfigT],
) -> list[Violation]:
violations = []

Expand Down