Skip to content

Commit 07bf10c

Browse files
author
KunxiSun
committed
feat: add IntEnum for sqltypes
1 parent 6c0410e commit 07bf10c

File tree

4 files changed

+87
-9
lines changed

4 files changed

+87
-9
lines changed

sqlmodel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,4 @@
141141
from .sql.expression import type_coerce as type_coerce
142142
from .sql.expression import within_group as within_group
143143
from .sql.sqltypes import AutoString as AutoString
144+
from .sql.sqltypes import IntEnum as IntEnum

sqlmodel/sql/sqltypes.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, cast
1+
from typing import Any, cast, Optional
2+
from enum import IntEnum as _IntEnum
23

34
from sqlalchemy import types
45
from sqlalchemy.engine.interfaces import Dialect
@@ -14,3 +15,57 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
1415
if impl.length is None and dialect.name == "mysql":
1516
return dialect.type_descriptor(types.String(self.mysql_default_length))
1617
return super().load_dialect_impl(dialect)
18+
19+
class IntEnum(types.TypeDecorator): # type: ignore
20+
"""TypeDecorator for Integer-enum conversion.
21+
22+
Automatically converts Python enum.IntEnum <-> database integers.
23+
24+
Args:
25+
enum_type (enum.IntEnum): Integer enum class (subclass of enum.IntEnum)
26+
27+
Example:
28+
>>> class HeroStatus(enum.IntEnum):
29+
... ACTIVE = 1
30+
... DISABLE = 2
31+
>>>>
32+
>>> from sqlmodel import IntEnum
33+
>>> class Hero(SQLModel):
34+
... hero_status: HeroStatus = Field(sa_type=sqlmodel.IntEnum(HeroStatus))
35+
>>> user.hero_status == Status.ACTIVE # Loads back as enum
36+
37+
Returns:
38+
Optional[enum.IntEnum]: Converted enum instance (None if database value is NULL)
39+
40+
Raises:
41+
TypeError: For invalid enum types
42+
"""
43+
44+
impl = types.Integer
45+
46+
def __init__(self, enum_type: _IntEnum, *args, **kwargs):
47+
super().__init__(*args, **kwargs)
48+
49+
# validate the input enum type
50+
if not issubclass(enum_type, _IntEnum):
51+
raise TypeError(
52+
f"Input must be enum.IntEnum"
53+
)
54+
55+
self.enum_type = enum_type
56+
57+
def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEnum]:
58+
59+
if value is None:
60+
return None
61+
62+
result = self.enum_type(value)
63+
return result
64+
65+
def process_bind_param(self, value: Optional[_IntEnum], dialect) -> Optional[int]:
66+
67+
if value is None:
68+
return None
69+
70+
result = value.value
71+
return result

tests/test_enums.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,21 @@ def test_json_schema_flat_model_pydantic_v1():
6363
"properties": {
6464
"id": {"title": "Id", "type": "string", "format": "uuid"},
6565
"enum_field": {"$ref": "#/definitions/MyEnum1"},
66+
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
6667
},
67-
"required": ["id", "enum_field"],
68+
"required": ["id", "enum_field", "int_enum_field"],
6869
"definitions": {
6970
"MyEnum1": {
7071
"title": "MyEnum1",
7172
"description": "An enumeration.",
7273
"enum": ["A", "B"],
7374
"type": "string",
75+
},
76+
"MyEnum3": {
77+
"title": "MyEnum3",
78+
"description": "An enumeration.",
79+
"enum": [1, 3],
80+
"type": "int",
7481
}
7582
},
7683
}
@@ -84,14 +91,21 @@ def test_json_schema_inherit_model_pydantic_v1():
8491
"properties": {
8592
"id": {"title": "Id", "type": "string", "format": "uuid"},
8693
"enum_field": {"$ref": "#/definitions/MyEnum2"},
94+
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
8795
},
88-
"required": ["id", "enum_field"],
96+
"required": ["id", "enum_field", "int_enum_field"],
8997
"definitions": {
9098
"MyEnum2": {
9199
"title": "MyEnum2",
92100
"description": "An enumeration.",
93101
"enum": ["C", "D"],
94102
"type": "string",
103+
},
104+
"MyEnum3": {
105+
"title": "MyEnum3",
106+
"description": "An int enumeration.",
107+
"enum": [1, 3],
108+
"type": "int",
95109
}
96110
},
97111
}
@@ -105,10 +119,12 @@ def test_json_schema_flat_model_pydantic_v2():
105119
"properties": {
106120
"id": {"title": "Id", "type": "string", "format": "uuid"},
107121
"enum_field": {"$ref": "#/$defs/MyEnum1"},
122+
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
108123
},
109-
"required": ["id", "enum_field"],
124+
"required": ["id", "enum_field", "int_enum_field"],
110125
"$defs": {
111-
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}
126+
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"},
127+
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
112128
},
113129
}
114130

@@ -121,9 +137,11 @@ def test_json_schema_inherit_model_pydantic_v2():
121137
"properties": {
122138
"id": {"title": "Id", "type": "string", "format": "uuid"},
123139
"enum_field": {"$ref": "#/$defs/MyEnum2"},
140+
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
124141
},
125-
"required": ["id", "enum_field"],
142+
"required": ["id", "enum_field", "int_enum_field"],
126143
"$defs": {
127-
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}
144+
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"},
145+
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
128146
},
129147
}

tests/test_enums_models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import enum
22
import uuid
33

4-
from sqlmodel import Field, SQLModel
4+
from sqlmodel import Field, SQLModel, IntEnum
55

66

77
class MyEnum1(str, enum.Enum):
@@ -13,15 +13,19 @@ class MyEnum2(str, enum.Enum):
1313
C = "C"
1414
D = "D"
1515

16+
class MyEnum3(enum.IntEnum):
17+
E = 1
18+
F = 2
1619

1720
class BaseModel(SQLModel):
1821
id: uuid.UUID = Field(primary_key=True)
1922
enum_field: MyEnum2
20-
23+
int_enum_field: MyEnum3
2124

2225
class FlatModel(SQLModel, table=True):
2326
id: uuid.UUID = Field(primary_key=True)
2427
enum_field: MyEnum1
28+
int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3))
2529

2630

2731
class InheritModel(BaseModel, table=True):

0 commit comments

Comments
 (0)