Skip to content

Commit 3104ae1

Browse files
chore: add explicit discriminator field for processors (#145)
* chore: add explicit discriminator field for processors * using new type instead of base class everywhere * lint * using base instead of type in some places * processor_type needs to be str for correct ser/de
1 parent d50a8ae commit 3104ae1

File tree

4 files changed

+19
-11
lines changed

4 files changed

+19
-11
lines changed

src/data_designer/config/config_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from data_designer.config.default_model_settings import get_default_model_configs
3030
from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
3131
from data_designer.config.models import ModelConfig, load_model_configs
32-
from data_designer.config.processors import ProcessorConfig, ProcessorType, get_processor_config_from_kwargs
32+
from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs
3333
from data_designer.config.sampler_constraints import (
3434
ColumnConstraintT,
3535
ColumnInequalityConstraint,
@@ -141,7 +141,7 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]]
141141
"""
142142
self._column_configs = {}
143143
self._model_configs = _load_model_configs(model_configs)
144-
self._processor_configs: list[ProcessorConfig] = []
144+
self._processor_configs: list[ProcessorConfigT] = []
145145
self._seed_config: Optional[SeedConfig] = None
146146
self._constraints: list[ColumnConstraintT] = []
147147
self._profilers: list[ColumnProfilerConfigT] = []
@@ -298,7 +298,7 @@ def add_constraint(
298298

299299
def add_processor(
300300
self,
301-
processor_config: Optional[ProcessorConfig] = None,
301+
processor_config: Optional[ProcessorConfigT] = None,
302302
*,
303303
processor_type: Optional[ProcessorType] = None,
304304
**kwargs,
@@ -487,7 +487,7 @@ def get_columns_excluding_type(self, column_type: DataDesignerColumnType) -> lis
487487
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
488488
return [c for c in self._column_configs.values() if c.column_type != column_type]
489489

490-
def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfig]]:
490+
def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]:
491491
"""Get processor configuration objects.
492492
493493
Returns:

src/data_designer/config/data_designer_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from data_designer.config.base import ExportableConfigBase
1212
from data_designer.config.column_types import ColumnConfigT
1313
from data_designer.config.models import ModelConfig
14-
from data_designer.config.processors import ProcessorConfig
14+
from data_designer.config.processors import ProcessorConfigT
1515
from data_designer.config.sampler_constraints import ColumnConstraintT
1616
from data_designer.config.seed import SeedConfig
1717

@@ -37,4 +37,4 @@ class DataDesignerConfig(ExportableConfigBase):
3737
seed_config: Optional[SeedConfig] = None
3838
constraints: Optional[list[ColumnConstraintT]] = None
3939
profilers: Optional[list[ColumnProfilerConfigT]] = None
40-
processors: Optional[list[ProcessorConfig]] = None
40+
processors: Optional[list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]]] = None

src/data_designer/config/processors.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import json
55
from abc import ABC
66
from enum import Enum
7-
from typing import Any, Literal
7+
from typing import Any, Literal, Union
88

99
from pydantic import Field, field_validator
10+
from typing_extensions import TypeAlias
1011

1112
from data_designer.config.base import ConfigBase
1213
from data_designer.config.dataset_builders import BuildStage
@@ -47,6 +48,7 @@ class ProcessorConfig(ConfigBase, ABC):
4748
default=BuildStage.POST_BATCH,
4849
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
4950
)
51+
processor_type: str
5052

5153
@field_validator("build_stage")
5254
def validate_build_stage(cls, v: BuildStage) -> BuildStage:
@@ -139,3 +141,9 @@ def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
139141
if "not JSON serializable" in str(e):
140142
raise InvalidConfigError("Template must be JSON serializable")
141143
return v
144+
145+
146+
ProcessorConfigT: TypeAlias = Union[
147+
DropColumnsProcessorConfig,
148+
SchemaTransformProcessorConfig,
149+
]

src/data_designer/config/utils/validation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from rich.panel import Panel
1717

1818
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
19-
from data_designer.config.processors import ProcessorConfig, ProcessorType
19+
from data_designer.config.processors import ProcessorConfigT, ProcessorType
2020
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
2121
from data_designer.config.utils.misc import (
2222
can_run_data_designer_locally,
@@ -57,7 +57,7 @@ def has_column(self) -> bool:
5757

5858
def validate_data_designer_config(
5959
columns: list[ColumnConfigT],
60-
processor_configs: list[ProcessorConfig],
60+
processor_configs: list[ProcessorConfigT],
6161
allowed_references: list[str],
6262
) -> list[Violation]:
6363
violations = []
@@ -273,7 +273,7 @@ def validate_columns_not_all_dropped(
273273

274274
def validate_drop_columns_processor(
275275
columns: list[ColumnConfigT],
276-
processor_configs: list[ProcessorConfig],
276+
processor_configs: list[ProcessorConfigT],
277277
) -> list[Violation]:
278278
all_column_names = {c.name for c in columns}
279279
for processor_config in processor_configs:
@@ -294,7 +294,7 @@ def validate_drop_columns_processor(
294294

295295
def validate_schema_transform_processor(
296296
columns: list[ColumnConfigT],
297-
processor_configs: list[ProcessorConfig],
297+
processor_configs: list[ProcessorConfigT],
298298
) -> list[Violation]:
299299
violations = []
300300

0 commit comments

Comments
 (0)