Skip to content

Commit 728e319

Browse files
authored
chore: Make column_type a pydantic field rather than a property (#17)
* make column_type a pydantic field rather than property
* dynamically create the dd column type enum
* remove unused report module
* refine errors a bit
* update action version
* pr feedback
1 parent 9ccc772 commit 728e319

File tree

19 files changed

+309
-144
lines changed

19 files changed

+309
-144
lines changed

.github/workflows/dco-assistant.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
steps:
2828
- name: "DCO Assistant"
2929
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the Contributor Agreement including DCO and I hereby sign the Contributor Agreement and DCO') || github.event_name == 'pull_request_target'
30-
uses: contributor-assistant/github-action@v2.6.1
30+
uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08
3131
env:
3232
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
3333
PERSONAL_ACCESS_TOKEN: ${{ secrets.DCO_ASSISTANT_TOKEN }}

src/data_designer/config/analysis/dataset_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pydantic import BaseModel, Field, field_validator
99

10-
from ..columns import DataDesignerColumnType
10+
from ..columns import DataDesignerColumnType, get_column_display_order
1111
from ..utils.constants import EPSILON
1212
from ..utils.numerical_helpers import prepare_number_for_reporting
1313
from .column_profilers import ColumnProfilerResultsT
@@ -32,7 +32,7 @@ def percent_complete(self) -> float:
3232

3333
@cached_property
3434
def column_types(self) -> list[str]:
35-
display_order = DataDesignerColumnType.get_display_order()
35+
display_order = get_column_display_order()
3636
return sorted(
3737
list(set([c.column_type for c in self.column_statistics])),
3838
key=lambda x: display_order.index(x) if x in display_order else len(display_order),

src/data_designer/config/analysis/utils/reporting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from rich.text import Text
1616

1717
from ...analysis.column_statistics import CategoricalHistogramData
18-
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType
18+
from ...columns import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
1919
from ...utils.visualization import (
2020
ColorPalette,
2121
convert_to_row_element,
@@ -44,7 +44,7 @@ class ReportSection(str, Enum):
4444
DEFAULT_INCLUDE_SECTIONS = [
4545
ReportSection.OVERVIEW,
4646
ReportSection.COLUMN_PROFILERS,
47-
] + DataDesignerColumnType.get_display_order()
47+
] + get_column_display_order()
4848

4949

5050
def generate_analysis_report(

src/data_designer/config/columns.py

Lines changed: 69 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from abc import ABC, abstractmethod
5-
from enum import Enum
4+
from abc import ABC
65
from typing import Literal, Optional, Type, Union
76

87
from pydantic import BaseModel, Field, model_validator
@@ -15,56 +14,14 @@
1514
from .utils.code_lang import CodeLang
1615
from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX
1716
from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
18-
from .utils.type_helpers import SAMPLER_PARAMS, resolve_string_enum
17+
from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum
1918
from .validator_params import ValidatorParamsT, ValidatorType
2019

2120

22-
class DataDesignerColumnType(str, Enum):
23-
SAMPLER = "sampler"
24-
LLM_TEXT = "llm-text"
25-
LLM_CODE = "llm-code"
26-
LLM_STRUCTURED = "llm-structured"
27-
LLM_JUDGE = "llm-judge"
28-
EXPRESSION = "expression"
29-
VALIDATION = "validation"
30-
SEED_DATASET = "seed-dataset"
31-
32-
@staticmethod
33-
def get_display_order() -> list[Self]:
34-
return [
35-
DataDesignerColumnType.SEED_DATASET,
36-
DataDesignerColumnType.SAMPLER,
37-
DataDesignerColumnType.LLM_TEXT,
38-
DataDesignerColumnType.LLM_CODE,
39-
DataDesignerColumnType.LLM_STRUCTURED,
40-
DataDesignerColumnType.LLM_JUDGE,
41-
DataDesignerColumnType.VALIDATION,
42-
DataDesignerColumnType.EXPRESSION,
43-
]
44-
45-
@property
46-
def has_prompt_templates(self) -> bool:
47-
return self in [self.LLM_TEXT, self.LLM_CODE, self.LLM_STRUCTURED, self.LLM_JUDGE]
48-
49-
@property
50-
def is_dag_column_type(self) -> bool:
51-
return self in [
52-
self.EXPRESSION,
53-
self.LLM_CODE,
54-
self.LLM_JUDGE,
55-
self.LLM_STRUCTURED,
56-
self.LLM_TEXT,
57-
self.VALIDATION,
58-
]
59-
60-
6121
class SingleColumnConfig(ConfigBase, ABC):
6222
name: str
6323
drop: bool = False
64-
65-
@property
66-
@abstractmethod
67-
def column_type(self) -> DataDesignerColumnType: ...
24+
column_type: str
6825

6926
@property
7027
def required_columns(self) -> list[str]:
@@ -80,21 +37,15 @@ class SamplerColumnConfig(SingleColumnConfig):
8037
params: SamplerParamsT
8138
conditional_params: dict[str, SamplerParamsT] = {}
8239
convert_to: Optional[str] = None
83-
84-
@property
85-
def column_type(self) -> DataDesignerColumnType:
86-
return DataDesignerColumnType.SAMPLER
40+
column_type: Literal["sampler"] = "sampler"
8741

8842

8943
class LLMTextColumnConfig(SingleColumnConfig):
9044
prompt: str
9145
model_alias: str
9246
system_prompt: Optional[str] = None
9347
multi_modal_context: Optional[list[ImageContext]] = None
94-
95-
@property
96-
def column_type(self) -> DataDesignerColumnType:
97-
return DataDesignerColumnType.LLM_TEXT
48+
column_type: Literal["llm-text"] = "llm-text"
9849

9950
@property
10051
def required_columns(self) -> list[str]:
@@ -117,18 +68,12 @@ def assert_prompt_valid_jinja(self) -> Self:
11768

11869
class LLMCodeColumnConfig(LLMTextColumnConfig):
11970
code_lang: CodeLang
120-
121-
@property
122-
def column_type(self) -> DataDesignerColumnType:
123-
return DataDesignerColumnType.LLM_CODE
71+
column_type: Literal["llm-code"] = "llm-code"
12472

12573

12674
class LLMStructuredColumnConfig(LLMTextColumnConfig):
12775
output_format: Union[dict, Type[BaseModel]]
128-
129-
@property
130-
def column_type(self) -> DataDesignerColumnType:
131-
return DataDesignerColumnType.LLM_STRUCTURED
76+
column_type: Literal["llm-structured"] = "llm-structured"
13277

13378
@model_validator(mode="after")
13479
def validate_output_format(self) -> Self:
@@ -145,20 +90,14 @@ class Score(ConfigBase):
14590

14691
class LLMJudgeColumnConfig(LLMTextColumnConfig):
14792
scores: list[Score] = Field(..., min_length=1)
148-
149-
@property
150-
def column_type(self) -> DataDesignerColumnType:
151-
return DataDesignerColumnType.LLM_JUDGE
93+
column_type: Literal["llm-judge"] = "llm-judge"
15294

15395

15496
class ExpressionColumnConfig(SingleColumnConfig):
15597
name: str
15698
expr: str
15799
dtype: Literal["int", "float", "str", "bool"] = "str"
158-
159-
@property
160-
def column_type(self) -> DataDesignerColumnType:
161-
return DataDesignerColumnType.EXPRESSION
100+
column_type: Literal["expression"] = "expression"
162101

163102
@property
164103
def required_columns(self) -> list[str]:
@@ -168,7 +107,9 @@ def required_columns(self) -> list[str]:
168107
def assert_expression_valid_jinja(self) -> Self:
169108
if not self.expr.strip():
170109
raise InvalidConfigError(
171-
f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') or remove this column if not needed."
110+
f"🛑 Expression column '{self.name}' has an empty or whitespace-only expression. "
111+
f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') "
112+
"or remove this column if not needed."
172113
)
173114
assert_valid_jinja2_template(self.expr)
174115
return self
@@ -179,20 +120,34 @@ class ValidationColumnConfig(SingleColumnConfig):
179120
validator_type: ValidatorType
180121
validator_params: ValidatorParamsT
181122
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
182-
183-
@property
184-
def column_type(self) -> DataDesignerColumnType:
185-
return DataDesignerColumnType.VALIDATION
123+
column_type: Literal["validation"] = "validation"
186124

187125
@property
188126
def required_columns(self) -> list[str]:
189127
return self.target_columns
190128

191129

192130
class SeedDatasetColumnConfig(SingleColumnConfig):
193-
@property
194-
def column_type(self) -> DataDesignerColumnType:
195-
return DataDesignerColumnType.SEED_DATASET
131+
column_type: Literal["seed-dataset"] = "seed-dataset"
132+
133+
134+
ColumnConfigT: TypeAlias = Union[
135+
ExpressionColumnConfig,
136+
LLMCodeColumnConfig,
137+
LLMJudgeColumnConfig,
138+
LLMStructuredColumnConfig,
139+
LLMTextColumnConfig,
140+
SamplerColumnConfig,
141+
SeedDatasetColumnConfig,
142+
ValidationColumnConfig,
143+
]
144+
145+
146+
DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
147+
enum_name="DataDesignerColumnType",
148+
type_union=ColumnConfigT,
149+
discriminator_field_name="column_type",
150+
)
196151

197152

198153
COLUMN_TYPE_EMOJI_MAP = {
@@ -208,16 +163,28 @@ def column_type(self) -> DataDesignerColumnType:
208163
}
209164

210165

211-
ColumnConfigT: TypeAlias = Union[
212-
ExpressionColumnConfig,
213-
LLMCodeColumnConfig,
214-
LLMJudgeColumnConfig,
215-
LLMStructuredColumnConfig,
216-
LLMTextColumnConfig,
217-
SamplerColumnConfig,
218-
SeedDatasetColumnConfig,
219-
ValidationColumnConfig,
220-
]
166+
def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
167+
"""Return True if the column type is used in the workflow execution DAG."""
168+
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
169+
return column_type in {
170+
DataDesignerColumnType.EXPRESSION,
171+
DataDesignerColumnType.LLM_CODE,
172+
DataDesignerColumnType.LLM_JUDGE,
173+
DataDesignerColumnType.LLM_STRUCTURED,
174+
DataDesignerColumnType.LLM_TEXT,
175+
DataDesignerColumnType.VALIDATION,
176+
}
177+
178+
179+
def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
180+
"""Return True if the column type is an LLM-generated column."""
181+
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
182+
return column_type in {
183+
DataDesignerColumnType.LLM_TEXT,
184+
DataDesignerColumnType.LLM_CODE,
185+
DataDesignerColumnType.LLM_STRUCTURED,
186+
DataDesignerColumnType.LLM_JUDGE,
187+
}
221188

222189

223190
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
@@ -251,6 +218,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
251218
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
252219

253220

221+
def get_column_display_order() -> list[DataDesignerColumnType]:
222+
"""Return the preferred display order of the column types."""
223+
return [
224+
DataDesignerColumnType.SEED_DATASET,
225+
DataDesignerColumnType.SAMPLER,
226+
DataDesignerColumnType.LLM_TEXT,
227+
DataDesignerColumnType.LLM_CODE,
228+
DataDesignerColumnType.LLM_STRUCTURED,
229+
DataDesignerColumnType.LLM_JUDGE,
230+
DataDesignerColumnType.VALIDATION,
231+
DataDesignerColumnType.EXPRESSION,
232+
]
233+
234+
254235
def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict:
255236
if "sampler_type" not in kwargs:
256237
raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.")

src/data_designer/config/config_builder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515

1616
from .analysis.column_profilers import ColumnProfilerConfigT
1717
from .base import ExportableConfigBase
18-
from .columns import ColumnConfigT, DataDesignerColumnType, SeedDatasetColumnConfig, get_column_config_from_kwargs
18+
from .columns import (
19+
ColumnConfigT,
20+
DataDesignerColumnType,
21+
SeedDatasetColumnConfig,
22+
column_type_is_llm_generated,
23+
get_column_config_from_kwargs,
24+
)
1925
from .data_designer_config import DataDesignerConfig
2026
from .dataset_builders import BuildStage
2127
from .datastore import DatastoreSettings, fetch_seed_dataset_column_names
@@ -449,7 +455,7 @@ def get_llm_gen_columns(self) -> list[ColumnConfigT]:
449455
Returns:
450456
A list of column configurations that use LLM generation.
451457
"""
452-
return [c for c in self._column_configs.values() if c.column_type.has_prompt_templates]
458+
return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)]
453459

454460
def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
455461
"""Get all column configurations of the specified type.

src/data_designer/config/data_designer_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from typing import Optional
6+
from typing import Annotated, Optional
77

88
from pydantic import Field
99

@@ -32,7 +32,7 @@ class DataDesignerConfig(ExportableConfigBase):
3232
profilers: Optional list of column profilers for analyzing generated data characteristics.
3333
"""
3434

35-
columns: list[ColumnConfigT] = Field(min_length=1)
35+
columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1)
3636
model_configs: Optional[list[ModelConfig]] = None
3737
seed_config: Optional[SeedConfig] = None
3838
constraints: Optional[list[ColumnConstraintT]] = None

src/data_designer/config/utils/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ class UserJinjaTemplateSyntaxError(DataDesignerError): ...
1010
class InvalidEnumValueError(DataDesignerError): ...
1111

1212

13+
class InvalidTypeUnionError(DataDesignerError): ...
14+
15+
16+
class InvalidDiscriminatorFieldError(DataDesignerError): ...
17+
18+
1319
class DatasetSampleDisplayError(DataDesignerError): ...

0 commit comments

Comments
 (0)