-
Notifications
You must be signed in to change notification settings - Fork 51
chore: Make column_type a pydantic field rather than a property #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c677963
e874367
c143ef8
17b984f
ac4539b
ab20e4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,7 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from enum import Enum | ||
| from abc import ABC | ||
| from typing import Literal, Optional, Type, Union | ||
|
|
||
| from pydantic import BaseModel, Field, model_validator | ||
|
|
@@ -15,56 +14,14 @@ | |
| from .utils.code_lang import CodeLang | ||
| from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX | ||
| from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords | ||
| from .utils.type_helpers import SAMPLER_PARAMS, resolve_string_enum | ||
| from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum | ||
| from .validator_params import ValidatorParamsT, ValidatorType | ||
|
|
||
|
|
||
| class DataDesignerColumnType(str, Enum): | ||
| SAMPLER = "sampler" | ||
| LLM_TEXT = "llm-text" | ||
| LLM_CODE = "llm-code" | ||
| LLM_STRUCTURED = "llm-structured" | ||
| LLM_JUDGE = "llm-judge" | ||
| EXPRESSION = "expression" | ||
| VALIDATION = "validation" | ||
| SEED_DATASET = "seed-dataset" | ||
|
|
||
| @staticmethod | ||
| def get_display_order() -> list[Self]: | ||
| return [ | ||
| DataDesignerColumnType.SEED_DATASET, | ||
| DataDesignerColumnType.SAMPLER, | ||
| DataDesignerColumnType.LLM_TEXT, | ||
| DataDesignerColumnType.LLM_CODE, | ||
| DataDesignerColumnType.LLM_STRUCTURED, | ||
| DataDesignerColumnType.LLM_JUDGE, | ||
| DataDesignerColumnType.VALIDATION, | ||
| DataDesignerColumnType.EXPRESSION, | ||
| ] | ||
|
|
||
| @property | ||
| def has_prompt_templates(self) -> bool: | ||
| return self in [self.LLM_TEXT, self.LLM_CODE, self.LLM_STRUCTURED, self.LLM_JUDGE] | ||
|
|
||
| @property | ||
| def is_dag_column_type(self) -> bool: | ||
| return self in [ | ||
| self.EXPRESSION, | ||
| self.LLM_CODE, | ||
| self.LLM_JUDGE, | ||
| self.LLM_STRUCTURED, | ||
| self.LLM_TEXT, | ||
| self.VALIDATION, | ||
| ] | ||
|
|
||
|
|
||
| class SingleColumnConfig(ConfigBase, ABC): | ||
| name: str | ||
| drop: bool = False | ||
|
|
||
| @property | ||
| @abstractmethod | ||
| def column_type(self) -> DataDesignerColumnType: ... | ||
| column_type: str | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
|
|
@@ -80,21 +37,15 @@ class SamplerColumnConfig(SingleColumnConfig): | |
| params: SamplerParamsT | ||
| conditional_params: dict[str, SamplerParamsT] = {} | ||
| convert_to: Optional[str] = None | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.SAMPLER | ||
| column_type: Literal["sampler"] = "sampler" | ||
|
|
||
|
|
||
| class LLMTextColumnConfig(SingleColumnConfig): | ||
| prompt: str | ||
| model_alias: str | ||
| system_prompt: Optional[str] = None | ||
| multi_modal_context: Optional[list[ImageContext]] = None | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.LLM_TEXT | ||
| column_type: Literal["llm-text"] = "llm-text" | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
|
|
@@ -117,18 +68,12 @@ def assert_prompt_valid_jinja(self) -> Self: | |
|
|
||
| class LLMCodeColumnConfig(LLMTextColumnConfig): | ||
| code_lang: CodeLang | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.LLM_CODE | ||
| column_type: Literal["llm-code"] = "llm-code" | ||
|
|
||
|
|
||
| class LLMStructuredColumnConfig(LLMTextColumnConfig): | ||
| output_format: Union[dict, Type[BaseModel]] | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.LLM_STRUCTURED | ||
| column_type: Literal["llm-structured"] = "llm-structured" | ||
|
|
||
| @model_validator(mode="after") | ||
| def validate_output_format(self) -> Self: | ||
|
|
@@ -145,20 +90,14 @@ class Score(ConfigBase): | |
|
|
||
| class LLMJudgeColumnConfig(LLMTextColumnConfig): | ||
| scores: list[Score] = Field(..., min_length=1) | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.LLM_JUDGE | ||
| column_type: Literal["llm-judge"] = "llm-judge" | ||
|
|
||
|
|
||
| class ExpressionColumnConfig(SingleColumnConfig): | ||
| name: str | ||
| expr: str | ||
| dtype: Literal["int", "float", "str", "bool"] = "str" | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.EXPRESSION | ||
| column_type: Literal["expression"] = "expression" | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
|
|
@@ -168,7 +107,9 @@ def required_columns(self) -> list[str]: | |
| def assert_expression_valid_jinja(self) -> Self: | ||
| if not self.expr.strip(): | ||
| raise InvalidConfigError( | ||
| f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') or remove this column if not needed." | ||
| f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. " | ||
| f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') " | ||
| "or remove this column if not needed." | ||
| ) | ||
| assert_valid_jinja2_template(self.expr) | ||
| return self | ||
|
|
@@ -179,20 +120,34 @@ class ValidationColumnConfig(SingleColumnConfig): | |
| validator_type: ValidatorType | ||
| validator_params: ValidatorParamsT | ||
| batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") | ||
|
|
||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.VALIDATION | ||
| column_type: Literal["validation"] = "validation" | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
| return self.target_columns | ||
|
|
||
|
|
||
| class SeedDatasetColumnConfig(SingleColumnConfig): | ||
| @property | ||
| def column_type(self) -> DataDesignerColumnType: | ||
| return DataDesignerColumnType.SEED_DATASET | ||
| column_type: Literal["seed-dataset"] = "seed-dataset" | ||
|
|
||
|
|
||
| ColumnConfigT: TypeAlias = Union[ | ||
| ExpressionColumnConfig, | ||
| LLMCodeColumnConfig, | ||
| LLMJudgeColumnConfig, | ||
| LLMStructuredColumnConfig, | ||
| LLMTextColumnConfig, | ||
| SamplerColumnConfig, | ||
| SeedDatasetColumnConfig, | ||
| ValidationColumnConfig, | ||
| ] | ||
|
|
||
|
|
||
| DataDesignerColumnType = create_str_enum_from_discriminated_type_union( | ||
| enum_name="DataDesignerColumnType", | ||
| type_union=ColumnConfigT, | ||
| discriminator_field_name="column_type", | ||
| ) | ||
|
Comment on lines +146 to +150
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here's where we dynamically create this StrEnum. |
||
|
|
||
|
|
||
| COLUMN_TYPE_EMOJI_MAP = { | ||
|
|
@@ -208,16 +163,28 @@ def column_type(self) -> DataDesignerColumnType: | |
| } | ||
|
|
||
|
|
||
| ColumnConfigT: TypeAlias = Union[ | ||
| ExpressionColumnConfig, | ||
| LLMCodeColumnConfig, | ||
| LLMJudgeColumnConfig, | ||
| LLMStructuredColumnConfig, | ||
| LLMTextColumnConfig, | ||
| SamplerColumnConfig, | ||
| SeedDatasetColumnConfig, | ||
| ValidationColumnConfig, | ||
| ] | ||
| def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: | ||
| """Return True if the column type is used in the workflow execution DAG.""" | ||
| column_type = resolve_string_enum(column_type, DataDesignerColumnType) | ||
| return column_type in { | ||
| DataDesignerColumnType.EXPRESSION, | ||
| DataDesignerColumnType.LLM_CODE, | ||
| DataDesignerColumnType.LLM_JUDGE, | ||
| DataDesignerColumnType.LLM_STRUCTURED, | ||
| DataDesignerColumnType.LLM_TEXT, | ||
| DataDesignerColumnType.VALIDATION, | ||
|
Comment on lines +170 to +175
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does our type checker have an issue with this, since `DataDesignerColumnType` is now created dynamically?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're not running ty yet, but my IDE seems happy with all of it. If you can pull the branch to check your IDE settings, that might be helpful.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's change these to sets. |
||
| } | ||
|
|
||
|
|
||
| def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: | ||
| """Return True if the column type is an LLM-generated column.""" | ||
| column_type = resolve_string_enum(column_type, DataDesignerColumnType) | ||
| return column_type in { | ||
| DataDesignerColumnType.LLM_TEXT, | ||
| DataDesignerColumnType.LLM_CODE, | ||
| DataDesignerColumnType.LLM_STRUCTURED, | ||
| DataDesignerColumnType.LLM_JUDGE, | ||
| } | ||
|
|
||
|
|
||
| def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: | ||
|
|
@@ -251,6 +218,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType | |
| raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover | ||
|
|
||
|
|
||
| def get_column_display_order() -> list[DataDesignerColumnType]: | ||
| """Return the preferred display order of the column types.""" | ||
| return [ | ||
| DataDesignerColumnType.SEED_DATASET, | ||
| DataDesignerColumnType.SAMPLER, | ||
| DataDesignerColumnType.LLM_TEXT, | ||
| DataDesignerColumnType.LLM_CODE, | ||
| DataDesignerColumnType.LLM_STRUCTURED, | ||
| DataDesignerColumnType.LLM_JUDGE, | ||
| DataDesignerColumnType.VALIDATION, | ||
| DataDesignerColumnType.EXPRESSION, | ||
| ] | ||
|
|
||
|
|
||
| def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict: | ||
| if "sampler_type" not in kwargs: | ||
| raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Optional | ||
| from typing import Annotated, Optional | ||
|
|
||
| from pydantic import Field | ||
|
|
||
|
|
@@ -32,7 +32,7 @@ class DataDesignerConfig(ExportableConfigBase): | |
| profilers: Optional list of column profilers for analyzing generated data characteristics. | ||
| """ | ||
|
|
||
| columns: list[ColumnConfigT] = Field(min_length=1) | ||
| columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Now we can do this. |
||
| model_configs: Optional[list[ModelConfig]] = None | ||
| seed_config: Optional[SeedConfig] = None | ||
| constraints: Optional[list[ColumnConstraintT]] = None | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.