Skip to content

Commit a436172

Browse files
committed
refine errors a bit
1 parent 8c5a604 commit a436172

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

src/data_designer/config/utils/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ class UserJinjaTemplateSyntaxError(DataDesignerError): ...
1010
class InvalidEnumValueError(DataDesignerError): ...
1111

1212

13+
class InvalidTypeUnionError(DataDesignerError): ...
14+
15+
16+
class InvalidDiscriminatorFieldError(DataDesignerError): ...
17+
18+
1319
class DatasetSampleDisplayError(DataDesignerError): ...

src/data_designer/config/utils/type_helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
from enum import Enum
55
import inspect
6-
from typing import Any, Type, Union, get_args
6+
from typing import Any, Literal, Type, Union, get_args, get_origin
77

88
from pydantic import BaseModel
99

1010
from .. import sampler_params
11-
from .errors import InvalidEnumValueError
11+
from .errors import InvalidDiscriminatorFieldError, InvalidEnumValueError, InvalidTypeUnionError
1212

1313

1414
class StrEnum(str, Enum):
@@ -43,9 +43,11 @@ def create_str_enum_from_discriminated_type_union(
4343
discriminator_field_values = []
4444
for model in type_union.__args__:
4545
if not issubclass(model, BaseModel):
46-
raise ValueError(f"🛑 {model} is not a Pydantic model.")
46+
raise InvalidTypeUnionError(f"🛑 {model} must be a subclass of pydantic.BaseModel.")
4747
if discriminator_field_name not in model.model_fields:
48-
raise ValueError(f"🛑 {discriminator_field_name} is not a field of {model}.")
48+
raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' is not a field of {model}.")
49+
if get_origin(model.model_fields[discriminator_field_name].annotation) is not Literal:
50+
raise InvalidDiscriminatorFieldError(f"🛑 '{discriminator_field_name}' must be a Literal type.")
4951
discriminator_field_values.extend(get_args(model.model_fields[discriminator_field_name].annotation))
5052
return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)})
5153

tests/config/utils/test_type_helpers.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from pydantic import BaseModel
88
import pytest
99

10-
from data_designer.config.utils.errors import InvalidEnumValueError
10+
from data_designer.config.utils.errors import (
11+
InvalidDiscriminatorFieldError,
12+
InvalidEnumValueError,
13+
InvalidTypeUnionError,
14+
)
1115
from data_designer.config.utils.type_helpers import (
1216
SAMPLER_PARAMS,
1317
create_str_enum_from_discriminated_type_union,
@@ -94,16 +98,19 @@ class StubModelD(BaseModel):
9498
def test_create_str_enum_from_type_union_not_pydantic_model() -> None:
9599
type_union = Union[StubModelA, NotAModel]
96100

97-
with pytest.raises(ValueError, match="is not a Pydantic model"):
101+
with pytest.raises(InvalidTypeUnionError, match="must be a subclass of pydantic.BaseModel"):
98102
create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
99103

100104

101-
def test_create_str_enum_from_type_union_missing_discriminator_field() -> None:
105+
def test_create_str_enum_from_type_union_invalid_discriminator_field() -> None:
102106
type_union = Union[StubModelA, StubModelWithoutDiscriminator]
103107

104-
with pytest.raises(ValueError, match="column_type is not a field of"):
108+
with pytest.raises(InvalidDiscriminatorFieldError, match="'column_type' is not a field of"):
105109
create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type")
106110

111+
with pytest.raises(InvalidDiscriminatorFieldError, match="'name' must be a Literal type"):
112+
create_str_enum_from_discriminated_type_union("TestEnum", type_union, "name")
113+
107114

108115
def test_create_str_enum_from_type_union_custom_discriminator_name() -> None:
109116
class StubModelE(BaseModel):

0 commit comments

Comments
 (0)