diff --git a/.github/workflows/dco-assistant.yml b/.github/workflows/dco-assistant.yml index 91d5764f..5b1d97bf 100644 --- a/.github/workflows/dco-assistant.yml +++ b/.github/workflows/dco-assistant.yml @@ -27,7 +27,7 @@ jobs: steps: - name: "DCO Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the Contributor Agreement including DCO and I hereby sign the Contributor Agreement and DCO') || github.event_name == 'pull_request_target' - uses: contributor-assistant/github-action@v2.6.1 + uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PERSONAL_ACCESS_TOKEN: ${{ secrets.DCO_ASSISTANT_TOKEN }} diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index f9e0e168..aa2b638f 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, field_validator -from ..columns import DataDesignerColumnType +from ..columns import DataDesignerColumnType, get_column_display_order from ..utils.constants import EPSILON from ..utils.numerical_helpers import prepare_number_for_reporting from .column_profilers import ColumnProfilerResultsT @@ -32,7 +32,7 @@ def percent_complete(self) -> float: @cached_property def column_types(self) -> list[str]: - display_order = DataDesignerColumnType.get_display_order() + display_order = get_column_display_order() return sorted( list(set([c.column_type for c in self.column_statistics])), key=lambda x: display_order.index(x) if x in display_order else len(display_order), diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index ca62a0cd..e4df4190 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -15,7 +15,7 @@ from rich.text import Text from ...analysis.column_statistics import CategoricalHistogramData -from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType +from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order from ...utils.visualization import ( ColorPalette, convert_to_row_element, @@ -44,7 +44,7 @@ class ReportSection(str, Enum): DEFAULT_INCLUDE_SECTIONS = [ ReportSection.OVERVIEW, ReportSection.COLUMN_PROFILERS, -] + DataDesignerColumnType.get_display_order() +] + get_column_display_order() def generate_analysis_report( diff --git a/src/data_designer/config/columns.py b/src/data_designer/config/columns.py index c2449499..8886cb09 100644 --- a/src/data_designer/config/columns.py +++ b/src/data_designer/config/columns.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from abc import ABC, abstractmethod -from enum import Enum +from abc import ABC from typing import Literal, Optional, Type, Union from pydantic import BaseModel, Field, model_validator @@ -15,56 +14,14 @@ from .utils.code_lang import CodeLang from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords -from .utils.type_helpers import SAMPLER_PARAMS, resolve_string_enum +from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum from .validator_params import ValidatorParamsT, ValidatorType -class DataDesignerColumnType(str, Enum): - SAMPLER = "sampler" - LLM_TEXT = "llm-text" - LLM_CODE = "llm-code" - LLM_STRUCTURED = "llm-structured" - LLM_JUDGE = "llm-judge" - EXPRESSION = "expression" - VALIDATION = "validation" - SEED_DATASET = "seed-dataset" - - @staticmethod - def get_display_order() -> list[Self]: - return [ - DataDesignerColumnType.SEED_DATASET, - DataDesignerColumnType.SAMPLER, - DataDesignerColumnType.LLM_TEXT, - DataDesignerColumnType.LLM_CODE, - DataDesignerColumnType.LLM_STRUCTURED, - DataDesignerColumnType.LLM_JUDGE, - DataDesignerColumnType.VALIDATION, - DataDesignerColumnType.EXPRESSION, - ] - - @property - def has_prompt_templates(self) -> bool: - return self in [self.LLM_TEXT, self.LLM_CODE, self.LLM_STRUCTURED, self.LLM_JUDGE] - - @property - def is_dag_column_type(self) -> bool: - return self in [ - self.EXPRESSION, - self.LLM_CODE, - self.LLM_JUDGE, - self.LLM_STRUCTURED, - self.LLM_TEXT, - self.VALIDATION, - ] - - class SingleColumnConfig(ConfigBase, ABC): name: str drop: bool = False - - @property - @abstractmethod - def column_type(self) -> DataDesignerColumnType: ... + column_type: str @property def required_columns(self) -> list[str]: @@ -80,10 +37,7 @@ class SamplerColumnConfig(SingleColumnConfig): params: SamplerParamsT conditional_params: dict[str, SamplerParamsT] = {} convert_to: Optional[str] = None - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.SAMPLER + column_type: Literal["sampler"] = "sampler" class LLMTextColumnConfig(SingleColumnConfig): @@ -91,10 +45,7 @@ class LLMTextColumnConfig(SingleColumnConfig): model_alias: str system_prompt: Optional[str] = None multi_modal_context: Optional[list[ImageContext]] = None - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.LLM_TEXT + column_type: Literal["llm-text"] = "llm-text" @property def required_columns(self) -> list[str]: @@ -117,18 +68,12 @@ def assert_prompt_valid_jinja(self) -> Self: class LLMCodeColumnConfig(LLMTextColumnConfig): code_lang: CodeLang - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.LLM_CODE + column_type: Literal["llm-code"] = "llm-code" class LLMStructuredColumnConfig(LLMTextColumnConfig): output_format: Union[dict, Type[BaseModel]] - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.LLM_STRUCTURED + column_type: Literal["llm-structured"] = "llm-structured" @model_validator(mode="after") def validate_output_format(self) -> Self: @@ -145,20 +90,14 @@ class Score(ConfigBase): class LLMJudgeColumnConfig(LLMTextColumnConfig): scores: list[Score] = Field(..., min_length=1) - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.LLM_JUDGE + column_type: Literal["llm-judge"] = "llm-judge" class ExpressionColumnConfig(SingleColumnConfig): name: str expr: str dtype: Literal["int", "float", "str", "bool"] = "str" - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.EXPRESSION + column_type: Literal["expression"] = "expression" @property def required_columns(self) -> list[str]: @@ -168,7 +107,9 @@ def required_columns(self) -> list[str]: def assert_expression_valid_jinja(self) -> Self: if not self.expr.strip(): raise InvalidConfigError( - f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') or remove this column if not needed." + f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " + f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " + "or remove this column if not needed." ) assert_valid_jinja2_template(self.expr) return self @@ -179,10 +120,7 @@ class ValidationColumnConfig(SingleColumnConfig): validator_type: ValidatorType validator_params: ValidatorParamsT batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") - - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.VALIDATION + column_type: Literal["validation"] = "validation" @property def required_columns(self) -> list[str]: @@ -190,9 +128,26 @@ def required_columns(self) -> list[str]: class SeedDatasetColumnConfig(SingleColumnConfig): - @property - def column_type(self) -> DataDesignerColumnType: - return DataDesignerColumnType.SEED_DATASET + column_type: Literal["seed-dataset"] = "seed-dataset" + + +ColumnConfigT: TypeAlias = Union[ + ExpressionColumnConfig, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMStructuredColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + SeedDatasetColumnConfig, + ValidationColumnConfig, +] + + +DataDesignerColumnType = create_str_enum_from_discriminated_type_union( + enum_name="DataDesignerColumnType", + type_union=ColumnConfigT, + discriminator_field_name="column_type", +) COLUMN_TYPE_EMOJI_MAP = { @@ -208,16 +163,28 @@ def column_type(self) -> DataDesignerColumnType: } -ColumnConfigT: TypeAlias = Union[ - ExpressionColumnConfig, - LLMCodeColumnConfig, - LLMJudgeColumnConfig, - LLMStructuredColumnConfig, - LLMTextColumnConfig, - SamplerColumnConfig, - SeedDatasetColumnConfig, - ValidationColumnConfig, -] +def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: + """Return True if the column type is used in the workflow execution DAG.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + return column_type in { + DataDesignerColumnType.EXPRESSION, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.VALIDATION, + } + + +def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: + """Return True if the column type is an LLM-generated column.""" + column_type = resolve_string_enum(column_type, DataDesignerColumnType) + return column_type in { + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_JUDGE, + } def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: @@ -251,6 +218,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover +def get_column_display_order() -> list[DataDesignerColumnType]: + """Return the preferred display order of the column types.""" + return [ + DataDesignerColumnType.SEED_DATASET, + DataDesignerColumnType.SAMPLER, + DataDesignerColumnType.LLM_TEXT, + DataDesignerColumnType.LLM_CODE, + DataDesignerColumnType.LLM_STRUCTURED, + DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EXPRESSION, + ] + + def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: if "sampler_type" not in kwargs: raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.") diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 8bfe95e6..78cfe724 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -15,7 +15,13 @@ from .analysis.column_profilers import ColumnProfilerConfigT from .base import ExportableConfigBase -from .columns import ColumnConfigT, DataDesignerColumnType, SeedDatasetColumnConfig, get_column_config_from_kwargs +from .columns import ( + ColumnConfigT, + DataDesignerColumnType, + SeedDatasetColumnConfig, + column_type_is_llm_generated, + get_column_config_from_kwargs, +) from .data_designer_config import DataDesignerConfig from .dataset_builders import BuildStage from .datastore import DatastoreSettings, fetch_seed_dataset_column_names @@ -449,7 +455,7 @@ def get_llm_gen_columns(self) -> list[ColumnConfigT]: Returns: A list of column configurations that use LLM generation. """ - return [c for c in self._column_configs.values() if c.column_type.has_prompt_templates] + return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)] def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]: """Get all column configurations of the specified type. diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index ba717fd8..24d791b8 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Optional +from typing import Annotated, Optional from pydantic import Field @@ -32,7 +32,7 @@ class DataDesignerConfig(ExportableConfigBase): profilers: Optional list of column profilers for analyzing generated data characteristics. """ - columns: list[ColumnConfigT] = Field(min_length=1) + columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1) model_configs: Optional[list[ModelConfig]] = None seed_config: Optional[SeedConfig] = None constraints: Optional[list[ColumnConstraintT]] = None diff --git a/src/data_designer/config/utils/errors.py b/src/data_designer/config/utils/errors.py index 2caf45a1..3917cd04 100644 --- a/src/data_designer/config/utils/errors.py +++ b/src/data_designer/config/utils/errors.py @@ -10,4 +10,10 @@ class UserJinjaTemplateSyntaxError(DataDesignerError): ... class InvalidEnumValueError(DataDesignerError): ... +class InvalidTypeUnionError(DataDesignerError): ... + + +class InvalidDiscriminatorFieldError(DataDesignerError): ... + + class DatasetSampleDisplayError(DataDesignerError): ... diff --git a/src/data_designer/config/utils/type_helpers.py b/src/data_designer/config/utils/type_helpers.py index 02b17bb6..ba6181be 100644 --- a/src/data_designer/config/utils/type_helpers.py +++ b/src/data_designer/config/utils/type_helpers.py @@ -3,12 +3,53 @@ from enum import Enum import inspect -from typing import Any, Type +from typing import Any, Literal, Type, Union, get_args, get_origin from pydantic import BaseModel from .. import sampler_params -from .errors import InvalidEnumValueError +from .errors import InvalidDiscriminatorFieldError, InvalidEnumValueError, InvalidTypeUnionError + + +class StrEnum(str, Enum): + pass + + +def create_str_enum_from_discriminated_type_union( + enum_name: str, + type_union: Type[Union[BaseModel, ...]], + discriminator_field_name: str, +) -> StrEnum: + """Create a string enum from a type union. + + The type union is assumed to be a union of configs (Pydantic models) that have a discriminator field, + which must be a Literal string type - e.g., Literal["expression"]. + + Args: + enum_name: Name of the StrEnum. + type_union: Type union of configs (Pydantic models). + discriminator_field_name: Name of the discriminator field. + + Returns: + StrEnum with values being the discriminator field values of the configs in the type union. + + Example: + DataDesignerColumnType = create_str_enum_from_discriminated_type_union( + enum_name="DataDesignerColumnType", + type_union=ColumnConfigT, + discriminator_field_name="column_type", + ) + """ + discriminator_field_values = [] + for model in type_union.__args__: + if not issubclass(model, BaseModel): + raise InvalidTypeUnionError(f"🛑 {model} must be a subclass of pydantic.BaseModel.") + if discriminator_field_name not in model.model_fields: + raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' is not a field of {model}.") + if get_origin(model.model_fields[discriminator_field_name].annotation) is not Literal: + raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' must be a Literal type.") + discriminator_field_values.extend(get_args(model.model_fields[discriminator_field_name].annotation)) + return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)}) def get_sampler_params() -> dict[str, Type[BaseModel]]: diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index a3864d31..428c575c 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -15,7 +15,7 @@ from rich.padding import Padding from rich.panel import Panel -from ..columns import ColumnConfigT, DataDesignerColumnType +from ..columns import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated from ..processors import ProcessorConfig, ProcessorType from ..validator_params import ValidatorType from .constants import RICH_CONSOLE_THEME @@ -115,7 +115,7 @@ def validate_prompt_templates( ) -> list[Violation]: env = ImmutableSandboxedEnvironment() - columns_with_prompts = [c for c in columns if c.column_type.has_prompt_templates] + columns_with_prompts = [c for c in columns if column_type_is_llm_generated(c.column_type)] violations = [] for column in columns_with_prompts: diff --git a/src/data_designer/engine/analysis/dataset_profiler.py b/src/data_designer/engine/analysis/dataset_profiler.py index f9fe4888..3dd39ed0 100644 --- a/src/data_designer/engine/analysis/dataset_profiler.py +++ b/src/data_designer/engine/analysis/dataset_profiler.py @@ -83,7 +83,7 @@ def profile_dataset( profiler = self._create_column_profiler(profiler_config) applicable_column_types = profiler.metadata().applicable_column_types for c in self.config.column_configs: - if c.column_type.value in applicable_column_types: + if c.column_type in applicable_column_types: params = ColumnConfigWithDataFrame(column_config=c, df=dataset) column_profiles.append(profiler.profile(params)) if len(column_profiles) == 0: diff --git a/src/data_designer/engine/analysis/reporting/report.py b/src/data_designer/engine/analysis/reporting/report.py deleted file mode 100644 index f2f8beb8..00000000 --- a/src/data_designer/engine/analysis/reporting/report.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Report generation code goes here.""" diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_generators.py index 76ea97d3..14324f36 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ b/src/data_designer/engine/column_generators/generators/llm_generators.py @@ -108,7 +108,7 @@ def generate(self, data: dict) -> dict: def log_pre_generation(self) -> None: emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type.value} column generation") + logger.info(f"{emoji} Preparing {self.config.column_type} column generation") logger.info(f" |-- column name: {self.config.name!r}") logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") if self.model_config.provider is None: diff --git a/src/data_designer/engine/configurable_task.py b/src/data_designer/engine/configurable_task.py index 0c3f10a5..ee2dff46 100644 --- a/src/data_designer/engine/configurable_task.py +++ b/src/data_designer/engine/configurable_task.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, Type, TypeVar +from typing import Generic, Type, TypeVar, get_origin import pandas as pd @@ -32,8 +32,11 @@ def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | @classmethod def get_config_type(cls) -> Type[TaskConfigT]: for base in cls.__orig_bases__: - if hasattr(base, "__args__") and len(base.__args__) == 1 and issubclass(base.__args__[0], ConfigBase): - return base.__args__[0] + if hasattr(base, "__args__") and len(base.__args__) == 1: + arg = base.__args__[0] + origin = get_origin(arg) or arg + if isinstance(origin, type) and issubclass(origin, ConfigBase): + return base.__args__[0] raise TypeError( f"Could not determine config type for `{cls.__name__}`. Please ensure that the " "`ConfigurableTask` is defined with a generic type argument, where the type argument " diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index f16a5a96..26e794ca 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -10,7 +10,7 @@ import pandas as pd -from data_designer.config.columns import ColumnConfigT +from data_designer.config.columns import ColumnConfigT, column_type_is_llm_generated from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, @@ -64,10 +64,10 @@ def artifact_storage(self) -> ArtifactStorage: def single_column_configs(self) -> list[ColumnConfigT]: configs = [] for config in self._column_configs: - if isinstance(config, ColumnConfigT): - configs.append(config) - elif isinstance(config, MultiColumnConfig): + if isinstance(config, MultiColumnConfig): configs.extend(config.columns) + else: + configs.append(config) return configs def build( @@ -171,7 +171,7 @@ def _run_full_column_generator(self, generator: ColumnGenerator) -> None: self.batch_manager.update_records(df.to_dict(orient="records")) def _run_model_health_check_if_needed(self) -> bool: - if any(config.column_type.has_prompt_templates for config in self.single_column_configs): + if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs): self._resource_provider.model_registry.run_health_check() def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None: @@ -182,7 +182,7 @@ def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) ) logger.info( - f"🐙 Processing {generator.config.column_type.value} column '{generator.config.name}' " + f"🐙 Processing {generator.config.column_type} column '{generator.config.name}' " f"with {max_workers} concurrent workers" ) with ConcurrentThreadExecutor( diff --git a/src/data_designer/engine/dataset_builders/utils/dag.py b/src/data_designer/engine/dataset_builders/utils/dag.py index 65196b7c..e6d9da14 100644 --- a/src/data_designer/engine/dataset_builders/utils/dag.py +++ b/src/data_designer/engine/dataset_builders/utils/dag.py @@ -5,7 +5,7 @@ import networkx as nx -from data_designer.config.columns import ColumnConfigT +from data_designer.config.columns import ColumnConfigT, column_type_used_in_execution_dag from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError logger = logging.getLogger(__name__) @@ -14,8 +14,12 @@ def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]: dag = nx.DiGraph() - non_dag_column_config_list = [col for col in column_configs if not col.column_type.is_dag_column_type] - dag_column_config_dict = {col.name: col for col in column_configs if col.column_type.is_dag_column_type} + non_dag_column_config_list = [ + col for col in column_configs if not column_type_used_in_execution_dag(col.column_type) + ] + dag_column_config_dict = { + col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type) + } if len(dag_column_config_dict) == 0: return non_dag_column_config_list diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index 0bfa725d..df2d0668 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -15,7 +15,10 @@ Score, SeedDatasetColumnConfig, ValidationColumnConfig, + column_type_is_llm_generated, + column_type_used_in_execution_dag, get_column_config_from_kwargs, + get_column_display_order, ) from data_designer.config.errors import InvalidConfigError from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams @@ -29,7 +32,7 @@ def test_data_designer_column_type_get_display_order(): - assert DataDesignerColumnType.get_display_order() == [ + assert get_column_display_order() == [ DataDesignerColumnType.SEED_DATASET, DataDesignerColumnType.SAMPLER, DataDesignerColumnType.LLM_TEXT, @@ -41,26 +44,26 @@ def test_data_designer_column_type_get_display_order(): ] -def test_data_designer_column_type_has_prompt_templates(): - assert DataDesignerColumnType.LLM_TEXT.has_prompt_templates - assert DataDesignerColumnType.LLM_CODE.has_prompt_templates - assert DataDesignerColumnType.LLM_STRUCTURED.has_prompt_templates - assert DataDesignerColumnType.LLM_JUDGE.has_prompt_templates - assert not DataDesignerColumnType.SAMPLER.has_prompt_templates - assert not DataDesignerColumnType.VALIDATION.has_prompt_templates - assert not DataDesignerColumnType.EXPRESSION.has_prompt_templates - assert not DataDesignerColumnType.SEED_DATASET.has_prompt_templates - - -def test_data_designer_column_type_is_dag_column_type(): - assert DataDesignerColumnType.EXPRESSION.is_dag_column_type - assert DataDesignerColumnType.LLM_CODE.is_dag_column_type - assert DataDesignerColumnType.LLM_JUDGE.is_dag_column_type - assert DataDesignerColumnType.LLM_STRUCTURED.is_dag_column_type - assert DataDesignerColumnType.LLM_TEXT.is_dag_column_type - assert DataDesignerColumnType.VALIDATION.is_dag_column_type - assert not DataDesignerColumnType.SAMPLER.is_dag_column_type - assert not DataDesignerColumnType.SEED_DATASET.is_dag_column_type +def test_data_designer_column_type_is_llm_generated(): + assert column_type_is_llm_generated(DataDesignerColumnType.LLM_TEXT) + assert column_type_is_llm_generated(DataDesignerColumnType.LLM_CODE) + assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED) + assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE) + assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER) + assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION) + assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION) + assert not column_type_is_llm_generated(DataDesignerColumnType.SEED_DATASET) + + +def test_data_designer_column_type_is_in_dag(): + assert column_type_used_in_execution_dag(DataDesignerColumnType.EXPRESSION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_CODE) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_JUDGE) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) + assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) + assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) + assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) + assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) def test_sampler_column_config(): diff --git a/tests/config/utils/test_type_helpers.py b/tests/config/utils/test_type_helpers.py index 36a2c4ab..83f0ffeb 100644 --- a/tests/config/utils/test_type_helpers.py +++ b/tests/config/utils/test_type_helpers.py @@ -2,17 +2,134 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum +from typing import Literal, Union +from pydantic import BaseModel import pytest -from data_designer.config.utils.errors import InvalidEnumValueError -from data_designer.config.utils.type_helpers import SAMPLER_PARAMS, get_sampler_params, resolve_string_enum +from data_designer.config.utils.errors import ( + InvalidDiscriminatorFieldError, + InvalidEnumValueError, + InvalidTypeUnionError, +) +from data_designer.config.utils.type_helpers import ( + SAMPLER_PARAMS, + create_str_enum_from_discriminated_type_union, + get_sampler_params, + resolve_string_enum, +) class StubTestEnum(str, Enum): TEST = "test" +class StubModelA(BaseModel): + column_type: Literal["type-a", "type-a-alt"] = "type-a" + name: str + + +class StubModelB(BaseModel): + column_type: Literal["type-b"] = "type-b" + value: int + + +class StubModelC(BaseModel): + column_type: Literal["type-c-with-dashes"] = "type-c-with-dashes" + data: str + + +class StubModelWithoutDiscriminator(BaseModel): + name: str + value: int + + +class NotAModel: + column_type: str = "not-a-model" + + +def test_create_str_enum_from_type_union_basic() -> None: + type_union = Union[StubModelA, StubModelB] + result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + assert issubclass(result, Enum) + assert issubclass(result, str) + assert hasattr(result, "TYPE_A") + assert hasattr(result, "TYPE_A_ALT") + assert hasattr(result, "TYPE_B") + assert result.TYPE_A.value == "type-a" + assert result.TYPE_A_ALT.value == "type-a-alt" + assert result.TYPE_B.value == "type-b" + assert len(result) == 3 + + +def test_create_str_enum_from_type_union_with_dashes() -> None: + type_union = Union[StubModelC, StubModelA] + result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + assert hasattr(result, "TYPE_C_WITH_DASHES") + assert result.TYPE_C_WITH_DASHES.value == "type-c-with-dashes" + + +def test_create_str_enum_from_type_union_multiple_models() -> None: + type_union = Union[StubModelA, StubModelB, StubModelC] + result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + assert len(result) == 4 + assert hasattr(result, "TYPE_A") + assert hasattr(result, "TYPE_A_ALT") + assert hasattr(result, "TYPE_B") + assert hasattr(result, "TYPE_C_WITH_DASHES") + + +def test_create_str_enum_from_type_union_duplicate_values() -> None: + class StubModelD(BaseModel): + column_type: Literal["type-a"] = "type-a" + extra: str + + type_union = Union[StubModelA, StubModelD] + result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + assert len(result) == 2 + assert hasattr(result, "TYPE_A") + assert hasattr(result, "TYPE_A_ALT") + + +def test_create_str_enum_from_type_union_not_pydantic_model() -> None: + type_union = Union[StubModelA, NotAModel] + + with pytest.raises(InvalidTypeUnionError, match="must be a subclass of pydantic.BaseModel"): + create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + +def test_create_str_enum_from_type_union_invalid_discriminator_field() -> None: + type_union = Union[StubModelA, StubModelWithoutDiscriminator] + + with pytest.raises(InvalidDiscriminatorFieldError, match="'column_type' is not a field of"): + create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") + + with pytest.raises(InvalidDiscriminatorFieldError, match="'name' must be a Literal type"): + create_str_enum_from_discriminated_type_union("TestEnum", type_union, "name") + + +def test_create_str_enum_from_type_union_custom_discriminator_name() -> None: + class StubModelE(BaseModel): + type_field: Literal["custom-type"] = "custom-type" + name: str + + class StubModelF(BaseModel): + type_field: Literal["another-type"] = "another-type" + value: int + + type_union = Union[StubModelE, StubModelF] + result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "type_field") + + assert hasattr(result, "CUSTOM_TYPE") + assert result.CUSTOM_TYPE.value == "custom-type" + assert hasattr(result, "ANOTHER_TYPE") + assert result.ANOTHER_TYPE.value == "another-type" + + def test_get_sampler_params(): expected_sampler_keys = { "bernoulli", diff --git a/tests/conftest.py b/tests/conftest.py index a5b93e8b..62ecc60a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,31 +45,38 @@ def stub_data_designer_config_str() -> str: columns: - name: code_id sampler_type: uuid + column_type: sampler params: prefix: code_ short_form: true uppercase: true - name: age sampler_type: uniform + column_type: sampler params: low: 35 high: 88 - name: domain sampler_type: category + column_type: sampler params: values: [Healthcare, Finance, Education, Government] - name: topic sampler_type: category + column_type: sampler params: values: [Web Development, Data Science, Machine Learning, Cloud Computing] - name: text + column_type: llm-text prompt: Write a description of python code in topic {topic} and domain {domain} model_alias: my_own_code_model - name: code + column_type: llm-code prompt: Write Python code that will be paired with the following prompt {text} model_alias: my_own_code_model code_lang: python - name: code_validation_result + column_type: validation target_columns: - code validator_type: code @@ -77,6 +84,7 @@ def stub_data_designer_config_str() -> str: code_lang: python - name: code_judge_result model_alias: my_own_code_model + column_type: llm-judge prompt: You are an expert in Python programming and make appropriate judgement on the quality of the code. scores: - name: Pythonic diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index 487d143f..ebb9c72a 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -121,9 +121,9 @@ def test_seed_dataset_column_generator_config_structure(): assert config.sampling_strategy == SamplingStrategy.SHUFFLE assert len(config.columns) == 2 assert config.columns[0].name == "col1" - assert config.columns[0].column_type.value == "seed-dataset" + assert config.columns[0].column_type == "seed-dataset" assert config.columns[1].name == "col2" - assert config.columns[1].column_type.value == "seed-dataset" + assert config.columns[1].column_type == "seed-dataset" assert config.selection_strategy is None # Test PartitionBlock selection strategy