Skip to content

Commit 8caef28

Browse files
committed
dynamically create the dd column type enum
1 parent 772853b commit 8caef28

File tree

3 files changed

+171
-27
lines changed

3 files changed

+171
-27
lines changed

src/data_designer/config/columns.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from abc import ABC
5-
from enum import Enum
65
from typing import Literal, Optional, Type, Union
76

87
from pydantic import BaseModel, Field, model_validator
@@ -15,21 +14,10 @@
1514
from .utils.code_lang import CodeLang
1615
from .utils.constants import REASONING_TRACE_COLUMN_POSTFIX
1716
from .utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
18-
from .utils.type_helpers import SAMPLER_PARAMS, resolve_string_enum
17+
from .utils.type_helpers import SAMPLER_PARAMS, create_str_enum_from_discriminated_type_union, resolve_string_enum
1918
from .validator_params import ValidatorParamsT, ValidatorType
2019

2120

22-
class DataDesignerColumnType(str, Enum):
23-
SAMPLER = "sampler"
24-
LLM_TEXT = "llm-text"
25-
LLM_CODE = "llm-code"
26-
LLM_STRUCTURED = "llm-structured"
27-
LLM_JUDGE = "llm-judge"
28-
EXPRESSION = "expression"
29-
VALIDATION = "validation"
30-
SEED_DATASET = "seed-dataset"
31-
32-
3321
class SingleColumnConfig(ConfigBase, ABC):
3422
name: str
3523
drop: bool = False
@@ -143,6 +131,25 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
143131
column_type: Literal["seed-dataset"] = "seed-dataset"
144132

145133

134+
ColumnConfigT: TypeAlias = Union[
135+
ExpressionColumnConfig,
136+
LLMCodeColumnConfig,
137+
LLMJudgeColumnConfig,
138+
LLMStructuredColumnConfig,
139+
LLMTextColumnConfig,
140+
SamplerColumnConfig,
141+
SeedDatasetColumnConfig,
142+
ValidationColumnConfig,
143+
]
144+
145+
146+
DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
147+
enum_name="DataDesignerColumnType",
148+
type_union=ColumnConfigT,
149+
discriminator_field_name="column_type",
150+
)
151+
152+
146153
COLUMN_TYPE_EMOJI_MAP = {
147154
"general": "⚛️", # possible analysis column type
148155
DataDesignerColumnType.EXPRESSION: "🧩",
@@ -156,18 +163,6 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
156163
}
157164

158165

159-
ColumnConfigT: TypeAlias = Union[
160-
ExpressionColumnConfig,
161-
LLMCodeColumnConfig,
162-
LLMJudgeColumnConfig,
163-
LLMStructuredColumnConfig,
164-
LLMTextColumnConfig,
165-
SamplerColumnConfig,
166-
SeedDatasetColumnConfig,
167-
ValidationColumnConfig,
168-
]
169-
170-
171166
def column_type_is_in_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
172167
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
173168
return column_type in [

src/data_designer/config/utils/type_helpers.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,53 @@
33

44
from enum import Enum
55
import inspect
6-
from typing import Any, Type
6+
from typing import Any, Type, Union, get_args
77

88
from pydantic import BaseModel
99

1010
from .. import sampler_params
1111
from .errors import InvalidEnumValueError
1212

1313

14+
class StrEnum(str, Enum):
15+
pass
16+
17+
18+
def create_str_enum_from_discriminated_type_union(
19+
enum_name: str,
20+
type_union: Type[Union[BaseModel, ...]],
21+
discriminator_field_name: str,
22+
) -> StrEnum:
23+
"""Create a string enum from a type union.
24+
25+
The type union is assumed to be a union of configs (Pydantic models) that have a discriminator field,
26+
which must be a Literal string type - e.g., Literal["expression"].
27+
28+
Args:
29+
enum_name: Name of the StrEnum.
30+
type_union: Type union of configs (Pydantic models).
31+
discriminator_field_name: Name of the discriminator field.
32+
33+
Returns:
34+
StrEnum with values being the discriminator field values of the configs in the type union.
35+
36+
Example:
37+
DataDesignerColumnType = create_str_enum_from_discriminated_type_union(
38+
enum_name="DataDesignerColumnType",
39+
type_union=ColumnConfigT,
40+
discriminator_field_name="column_type",
41+
)
42+
"""
43+
discriminator_field_values = []
44+
for model in type_union.__args__:
45+
if not issubclass(model, BaseModel):
46+
raise ValueError(f"🛑 {model} is not a Pydantic model.")
47+
if discriminator_field_name not in model.model_fields:
48+
raise ValueError(f"🛑 {discriminator_field_name} is not a field of {model}.")
49+
discriminator_field_values.extend(get_args(model.model_fields[discriminator_field_name].annotation))
50+
return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
51+
52+
1453
def get_sampler_params() -> dict[str, Type[BaseModel]]:
1554
"""Returns a dictionary of sampler parameter classes."""
1655
params_cls_list = [

tests/config/utils/test_type_helpers.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,127 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from enum import Enum
5+
from typing import Literal, Union
56

7+
from pydantic import BaseModel
68
import pytest
79

810
from data_designer.config.utils.errors import InvalidEnumValueError
9-
from data_designer.config.utils.type_helpers import SAMPLER_PARAMS, get_sampler_params, resolve_string_enum
11+
from data_designer.config.utils.type_helpers import (
12+
SAMPLER_PARAMS,
13+
create_str_enum_from_discriminated_type_union,
14+
get_sampler_params,
15+
resolve_string_enum,
16+
)
1017

1118

1219
class StubTestEnum(str, Enum):
1320
TEST = "test"
1421

1522

23+
class StubModelA(BaseModel):
24+
column_type: Literal["type-a", "type-a-alt"] = "type-a"
25+
name: str
26+
27+
28+
class StubModelB(BaseModel):
29+
column_type: Literal["type-b"] = "type-b"
30+
value: int
31+
32+
33+
class StubModelC(BaseModel):
34+
column_type: Literal["type-c-with-dashes"] = "type-c-with-dashes"
35+
data: str
36+
37+
38+
class StubModelWithoutDiscriminator(BaseModel):
39+
name: str
40+
value: int
41+
42+
43+
class NotAModel:
44+
column_type: str = "not-a-model"
45+
46+
47+
def test_create_str_enum_from_type_union_basic() -> None:
48+
type_union = Union[StubModelA, StubModelB]
49+
result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
50+
51+
assert issubclass(result, Enum)
52+
assert issubclass(result, str)
53+
assert hasattr(result, "TYPE_A")
54+
assert hasattr(result, "TYPE_A_ALT")
55+
assert hasattr(result, "TYPE_B")
56+
assert result.TYPE_A.value == "type-a"
57+
assert result.TYPE_A_ALT.value == "type-a-alt"
58+
assert result.TYPE_B.value == "type-b"
59+
assert len(result) == 3
60+
61+
62+
def test_create_str_enum_from_type_union_with_dashes() -> None:
63+
type_union = Union[StubModelC, StubModelA]
64+
result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
65+
66+
assert hasattr(result, "TYPE_C_WITH_DASHES")
67+
assert result.TYPE_C_WITH_DASHES.value == "type-c-with-dashes"
68+
69+
70+
def test_create_str_enum_from_type_union_multiple_models() -> None:
71+
type_union = Union[StubModelA, StubModelB, StubModelC]
72+
result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
73+
74+
assert len(result) == 4
75+
assert hasattr(result, "TYPE_A")
76+
assert hasattr(result, "TYPE_A_ALT")
77+
assert hasattr(result, "TYPE_B")
78+
assert hasattr(result, "TYPE_C_WITH_DASHES")
79+
80+
81+
def test_create_str_enum_from_type_union_duplicate_values() -> None:
82+
class StubModelD(BaseModel):
83+
column_type: Literal["type-a"] = "type-a"
84+
extra: str
85+
86+
type_union = Union[StubModelA, StubModelD]
87+
result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
88+
89+
assert len(result) == 2
90+
assert hasattr(result, "TYPE_A")
91+
assert hasattr(result, "TYPE_A_ALT")
92+
93+
94+
def test_create_str_enum_from_type_union_not_pydantic_model() -> None:
95+
type_union = Union[StubModelA, NotAModel]
96+
97+
with pytest.raises(ValueError, match="is not a Pydantic model"):
98+
create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
99+
100+
101+
def test_create_str_enum_from_type_union_missing_discriminator_field() -> None:
102+
type_union = Union[StubModelA, StubModelWithoutDiscriminator]
103+
104+
with pytest.raises(ValueError, match="column_type is not a field of"):
105+
create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
106+
107+
108+
def test_create_str_enum_from_type_union_custom_discriminator_name() -> None:
109+
class StubModelE(BaseModel):
110+
type_field: Literal["custom-type"] = "custom-type"
111+
name: str
112+
113+
class StubModelF(BaseModel):
114+
type_field: Literal["another-type"] = "another-type"
115+
value: int
116+
117+
type_union = Union[StubModelE, StubModelF]
118+
result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "type_field")
119+
120+
assert hasattr(result, "CUSTOM_TYPE")
121+
assert result.CUSTOM_TYPE.value == "custom-type"
122+
assert hasattr(result, "ANOTHER_TYPE")
123+
assert result.ANOTHER_TYPE.value == "another-type"
124+
125+
16126
def test_get_sampler_params():
17127
expected_sampler_keys = {
18128
"bernoulli",

0 commit comments

Comments
 (0)