Skip to content

Commit bae308a

Browse files
Fix merge conflict
2 parents f4107fc + 3c337af commit bae308a

File tree

17 files changed

+406
-66
lines changed

17 files changed

+406
-66
lines changed

guardrails/datatypes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,50 @@ class Percentage(ScalarType):
382382
tag = "percentage"
383383

384384

385+
@register_type("enum")
386+
class Enum(ScalarType):
387+
"""Element tag: `<enum>`"""
388+
389+
tag = "enum"
390+
391+
def __init__(
392+
self,
393+
children: Dict[str, Any],
394+
validators_attr: ValidatorsAttr,
395+
optional: bool,
396+
name: Optional[str],
397+
description: Optional[str],
398+
enum_values: TypedList[str],
399+
) -> None:
400+
super().__init__(children, validators_attr, optional, name, description)
401+
self.enum_values = enum_values
402+
403+
def from_str(self, s: str) -> Optional[str]:
404+
"""Create an Enum from a string."""
405+
if s is None:
406+
return None
407+
if s not in self.enum_values:
408+
raise ValueError(f"Invalid enum value: {s}")
409+
return s
410+
411+
@classmethod
412+
def from_xml(
413+
cls,
414+
enum_values: TypedList[str],
415+
validators: Sequence[ValidatorSpec],
416+
description: Optional[str] = None,
417+
strict: bool = False,
418+
) -> "Enum":
419+
return cls(
420+
children={},
421+
validators_attr=ValidatorsAttr.from_validators(validators, cls.tag, strict),
422+
optional=False,
423+
name=None,
424+
description=description,
425+
enum_values=enum_values,
426+
)
427+
428+
385429
@register_type("list")
386430
class List(NonScalarType):
387431
"""Element tag: `<list>`"""

guardrails/utils/json_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
DataType,
1212
Date,
1313
Email,
14+
Enum,
1415
Float,
1516
Integer,
1617
)
@@ -50,6 +51,7 @@ def verify(
5051
ListDataType: list,
5152
Date: str,
5253
Time: str,
54+
Enum: str,
5355
}
5456

5557
ignore_types = [

guardrails/utils/pydantic_utils/v1.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from copy import deepcopy
55
from datetime import date, time
6+
from enum import Enum
67
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
78

89
from pydantic import BaseModel, validator
@@ -14,6 +15,7 @@
1415
from guardrails.datatypes import Choice as ChoiceDataType
1516
from guardrails.datatypes import DataType
1617
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1719
from guardrails.datatypes import Float as FloatDataType
1820
from guardrails.datatypes import Integer as IntegerDataType
1921
from guardrails.datatypes import List as ListDataType
@@ -67,6 +69,19 @@ def is_dict(type_annotation: Any) -> bool:
6769
return False
6870

6971

72+
def is_enum(type_annotation: Any) -> bool:
73+
"""Check if a type_annotation is an enum."""
74+
75+
type_annotation = prepare_type_annotation(type_annotation)
76+
77+
try:
78+
if issubclass(type_annotation, Enum):
79+
return True
80+
except TypeError:
81+
pass
82+
return False
83+
84+
7085
def prepare_type_annotation(type_annotation: Union[ModelField, Type]) -> Type:
7186
"""Get the raw type annotation that can be used for downstream processing.
7287
@@ -262,6 +277,8 @@ def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]:
262277
return ListDataType
263278
elif is_dict(type_annotation):
264279
return ObjectDataType
280+
elif is_enum(type_annotation):
281+
return EnumDataType
265282
elif type_annotation == bool:
266283
return BooleanDataType
267284
elif type_annotation == date:
@@ -356,6 +373,12 @@ def convert_pydantic_model_to_datatype(
356373
strict=strict,
357374
discriminator_key=discriminator,
358375
)
376+
elif target_datatype == EnumDataType:
377+
assert issubclass(type_annotation, Enum)
378+
valid_choices = [choice.value for choice in type_annotation]
379+
children[field_name] = pydantic_field_to_datatype(
380+
EnumDataType, field, strict=strict, enum_values=valid_choices
381+
)
359382
elif isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
360383
children[field_name] = convert_pydantic_model_to_datatype(
361384
field, datatype=target_datatype, strict=strict

guardrails/utils/pydantic_utils/v2.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from copy import deepcopy
44
from datetime import date, time
5+
from enum import Enum
56
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
67

78
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
@@ -14,6 +15,7 @@
1415
from guardrails.datatypes import Choice as ChoiceDataType
1516
from guardrails.datatypes import DataType
1617
from guardrails.datatypes import Date as DateDataType
18+
from guardrails.datatypes import Enum as EnumDataType
1719
from guardrails.datatypes import Float as FloatDataType
1820
from guardrails.datatypes import Integer as IntegerDataType
1921
from guardrails.datatypes import List as ListDataType
@@ -88,6 +90,19 @@ def is_dict(type_annotation: Any) -> bool:
8890
return False
8991

9092

93+
def is_enum(type_annotation: Any) -> bool:
94+
"""Check if a type_annotation is an enum."""
95+
96+
type_annotation = prepare_type_annotation(type_annotation)
97+
98+
try:
99+
if issubclass(type_annotation, Enum):
100+
return True
101+
except TypeError:
102+
pass
103+
return False
104+
105+
91106
def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
92107
class BareModel(BaseModel):
93108
__annotations__ = getattr(model, "__annotations__", {})
@@ -277,6 +292,8 @@ def field_to_datatype(field: Union[FieldInfo, Type]) -> Type[DataType]:
277292
return ListDataType
278293
elif is_dict(type_annotation):
279294
return ObjectDataType
295+
elif is_enum(type_annotation):
296+
return EnumDataType
280297
elif type_annotation == bool:
281298
return BooleanDataType
282299
elif type_annotation == date:
@@ -382,6 +399,16 @@ def convert_pydantic_model_to_datatype(
382399
discriminator_key=discriminator,
383400
name=field_name,
384401
)
402+
elif target_datatype == EnumDataType:
403+
assert issubclass(type_annotation, Enum)
404+
valid_choices = [choice.value for choice in type_annotation]
405+
children[field_name] = pydantic_field_to_datatype(
406+
EnumDataType,
407+
field,
408+
strict=strict,
409+
enum_values=valid_choices,
410+
name=field_name,
411+
)
385412
elif is_pydantic_base_model(field.annotation):
386413
children[field_name] = convert_pydantic_model_to_datatype(
387414
field,

0 commit comments

Comments
 (0)