11# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- from abc import ABC , abstractmethod
5- from enum import Enum
4+ from abc import ABC
65from typing import Literal , Optional , Type , Union
76
87from pydantic import BaseModel , Field , model_validator
1514from .utils .code_lang import CodeLang
1615from .utils .constants import REASONING_TRACE_COLUMN_POSTFIX
1716from .utils .misc import assert_valid_jinja2_template , get_prompt_template_keywords
18- from .utils .type_helpers import SAMPLER_PARAMS , resolve_string_enum
17+ from .utils .type_helpers import SAMPLER_PARAMS , create_str_enum_from_discriminated_type_union , resolve_string_enum
1918from .validator_params import ValidatorParamsT , ValidatorType
2019
2120
22- class DataDesignerColumnType (str , Enum ):
23- SAMPLER = "sampler"
24- LLM_TEXT = "llm-text"
25- LLM_CODE = "llm-code"
26- LLM_STRUCTURED = "llm-structured"
27- LLM_JUDGE = "llm-judge"
28- EXPRESSION = "expression"
29- VALIDATION = "validation"
30- SEED_DATASET = "seed-dataset"
31-
32- @staticmethod
33- def get_display_order () -> list [Self ]:
34- return [
35- DataDesignerColumnType .SEED_DATASET ,
36- DataDesignerColumnType .SAMPLER ,
37- DataDesignerColumnType .LLM_TEXT ,
38- DataDesignerColumnType .LLM_CODE ,
39- DataDesignerColumnType .LLM_STRUCTURED ,
40- DataDesignerColumnType .LLM_JUDGE ,
41- DataDesignerColumnType .VALIDATION ,
42- DataDesignerColumnType .EXPRESSION ,
43- ]
44-
45- @property
46- def has_prompt_templates (self ) -> bool :
47- return self in [self .LLM_TEXT , self .LLM_CODE , self .LLM_STRUCTURED , self .LLM_JUDGE ]
48-
49- @property
50- def is_dag_column_type (self ) -> bool :
51- return self in [
52- self .EXPRESSION ,
53- self .LLM_CODE ,
54- self .LLM_JUDGE ,
55- self .LLM_STRUCTURED ,
56- self .LLM_TEXT ,
57- self .VALIDATION ,
58- ]
59-
60-
6121class SingleColumnConfig (ConfigBase , ABC ):
6222 name : str
6323 drop : bool = False
64-
65- @property
66- @abstractmethod
67- def column_type (self ) -> DataDesignerColumnType : ...
24+ column_type : str
6825
6926 @property
7027 def required_columns (self ) -> list [str ]:
@@ -80,21 +37,15 @@ class SamplerColumnConfig(SingleColumnConfig):
8037 params : SamplerParamsT
8138 conditional_params : dict [str , SamplerParamsT ] = {}
8239 convert_to : Optional [str ] = None
83-
84- @property
85- def column_type (self ) -> DataDesignerColumnType :
86- return DataDesignerColumnType .SAMPLER
40+ column_type : Literal ["sampler" ] = "sampler"
8741
8842
8943class LLMTextColumnConfig (SingleColumnConfig ):
9044 prompt : str
9145 model_alias : str
9246 system_prompt : Optional [str ] = None
9347 multi_modal_context : Optional [list [ImageContext ]] = None
94-
95- @property
96- def column_type (self ) -> DataDesignerColumnType :
97- return DataDesignerColumnType .LLM_TEXT
48+ column_type : Literal ["llm-text" ] = "llm-text"
9849
9950 @property
10051 def required_columns (self ) -> list [str ]:
@@ -117,18 +68,12 @@ def assert_prompt_valid_jinja(self) -> Self:
11768
11869class LLMCodeColumnConfig (LLMTextColumnConfig ):
11970 code_lang : CodeLang
120-
121- @property
122- def column_type (self ) -> DataDesignerColumnType :
123- return DataDesignerColumnType .LLM_CODE
71+ column_type : Literal ["llm-code" ] = "llm-code"
12472
12573
12674class LLMStructuredColumnConfig (LLMTextColumnConfig ):
12775 output_format : Union [dict , Type [BaseModel ]]
128-
129- @property
130- def column_type (self ) -> DataDesignerColumnType :
131- return DataDesignerColumnType .LLM_STRUCTURED
76+ column_type : Literal ["llm-structured" ] = "llm-structured"
13277
13378 @model_validator (mode = "after" )
13479 def validate_output_format (self ) -> Self :
@@ -145,20 +90,14 @@ class Score(ConfigBase):
14590
14691class LLMJudgeColumnConfig (LLMTextColumnConfig ):
14792 scores : list [Score ] = Field (..., min_length = 1 )
148-
149- @property
150- def column_type (self ) -> DataDesignerColumnType :
151- return DataDesignerColumnType .LLM_JUDGE
93+ column_type : Literal ["llm-judge" ] = "llm-judge"
15294
15395
15496class ExpressionColumnConfig (SingleColumnConfig ):
15597 name : str
15698 expr : str
15799 dtype : Literal ["int" , "float" , "str" , "bool" ] = "str"
158-
159- @property
160- def column_type (self ) -> DataDesignerColumnType :
161- return DataDesignerColumnType .EXPRESSION
100+ column_type : Literal ["expression" ] = "expression"
162101
163102 @property
164103 def required_columns (self ) -> list [str ]:
@@ -168,7 +107,9 @@ def required_columns(self) -> list[str]:
168107 def assert_expression_valid_jinja (self ) -> Self :
169108 if not self .expr .strip ():
170109 raise InvalidConfigError (
171- f"🛑 Expression column '{ self .name } ' has an empty or whitespace-only expression. Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') or remove this column if not needed."
110+ f"🛑 Expression column '{ self .name } ' has an empty or whitespace-only expression. "
111+ f"Please provide a valid Jinja2 expression (e.g., '{{ column_name }}' or '{{ col1 }} + {{ col2 }}') "
112+ "or remove this column if not needed."
172113 )
173114 assert_valid_jinja2_template (self .expr )
174115 return self
@@ -179,20 +120,34 @@ class ValidationColumnConfig(SingleColumnConfig):
179120 validator_type : ValidatorType
180121 validator_params : ValidatorParamsT
181122 batch_size : int = Field (default = 10 , ge = 1 , description = "Number of records to process in each batch" )
182-
183- @property
184- def column_type (self ) -> DataDesignerColumnType :
185- return DataDesignerColumnType .VALIDATION
123+ column_type : Literal ["validation" ] = "validation"
186124
187125 @property
188126 def required_columns (self ) -> list [str ]:
189127 return self .target_columns
190128
191129
192130class SeedDatasetColumnConfig (SingleColumnConfig ):
193- @property
194- def column_type (self ) -> DataDesignerColumnType :
195- return DataDesignerColumnType .SEED_DATASET
131+ column_type : Literal ["seed-dataset" ] = "seed-dataset"
132+
133+
134+ ColumnConfigT : TypeAlias = Union [
135+ ExpressionColumnConfig ,
136+ LLMCodeColumnConfig ,
137+ LLMJudgeColumnConfig ,
138+ LLMStructuredColumnConfig ,
139+ LLMTextColumnConfig ,
140+ SamplerColumnConfig ,
141+ SeedDatasetColumnConfig ,
142+ ValidationColumnConfig ,
143+ ]
144+
145+
146+ DataDesignerColumnType = create_str_enum_from_discriminated_type_union (
147+ enum_name = "DataDesignerColumnType" ,
148+ type_union = ColumnConfigT ,
149+ discriminator_field_name = "column_type" ,
150+ )
196151
197152
198153COLUMN_TYPE_EMOJI_MAP = {
@@ -208,16 +163,28 @@ def column_type(self) -> DataDesignerColumnType:
208163}
209164
210165
211- ColumnConfigT : TypeAlias = Union [
212- ExpressionColumnConfig ,
213- LLMCodeColumnConfig ,
214- LLMJudgeColumnConfig ,
215- LLMStructuredColumnConfig ,
216- LLMTextColumnConfig ,
217- SamplerColumnConfig ,
218- SeedDatasetColumnConfig ,
219- ValidationColumnConfig ,
220- ]
166+ def column_type_used_in_execution_dag (column_type : Union [str , DataDesignerColumnType ]) -> bool :
167+ """Return True if the column type is used in the workflow execution DAG."""
168+ column_type = resolve_string_enum (column_type , DataDesignerColumnType )
169+ return column_type in {
170+ DataDesignerColumnType .EXPRESSION ,
171+ DataDesignerColumnType .LLM_CODE ,
172+ DataDesignerColumnType .LLM_JUDGE ,
173+ DataDesignerColumnType .LLM_STRUCTURED ,
174+ DataDesignerColumnType .LLM_TEXT ,
175+ DataDesignerColumnType .VALIDATION ,
176+ }
177+
178+
179+ def column_type_is_llm_generated (column_type : Union [str , DataDesignerColumnType ]) -> bool :
180+ """Return True if the column type is an LLM-generated column."""
181+ column_type = resolve_string_enum (column_type , DataDesignerColumnType )
182+ return column_type in {
183+ DataDesignerColumnType .LLM_TEXT ,
184+ DataDesignerColumnType .LLM_CODE ,
185+ DataDesignerColumnType .LLM_STRUCTURED ,
186+ DataDesignerColumnType .LLM_JUDGE ,
187+ }
221188
222189
223190def get_column_config_from_kwargs (name : str , column_type : DataDesignerColumnType , ** kwargs ) -> ColumnConfigT :
@@ -251,6 +218,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
251218 raise InvalidColumnTypeError (f"🛑 { column_type } is not a valid column type." ) # pragma: no cover
252219
253220
221+ def get_column_display_order () -> list [DataDesignerColumnType ]:
222+ """Return the preferred display order of the column types."""
223+ return [
224+ DataDesignerColumnType .SEED_DATASET ,
225+ DataDesignerColumnType .SAMPLER ,
226+ DataDesignerColumnType .LLM_TEXT ,
227+ DataDesignerColumnType .LLM_CODE ,
228+ DataDesignerColumnType .LLM_STRUCTURED ,
229+ DataDesignerColumnType .LLM_JUDGE ,
230+ DataDesignerColumnType .VALIDATION ,
231+ DataDesignerColumnType .EXPRESSION ,
232+ ]
233+
234+
254235def _resolve_sampler_kwargs (name : str , kwargs : dict ) -> dict :
255236 if "sampler_type" not in kwargs :
256237 raise InvalidConfigError (f"🛑 `sampler_type` is required for sampler column '{ name } '." )
0 commit comments