Skip to content

Commit a87c097

Browse files
authored
feat: add Enum to default type decoders (#397)
Extends the default type decoders to handle Enum types by converting them to their underlying value during serialization
1 parent ade619b commit a87c097

File tree

9 files changed

+344
-222
lines changed

9 files changed

+344
-222
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.9.6"
25+
rev: "v0.9.7"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/_serialization.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,77 @@
1+
import datetime
2+
import enum
13
from typing import Any
24

5+
from typing_extensions import runtime_checkable
6+
7+
try:
8+
from pydantic import BaseModel # type: ignore # noqa: PGH003
9+
10+
PYDANTIC_INSTALLED = True
11+
except ImportError:
12+
from typing import ClassVar, Protocol
13+
14+
@runtime_checkable
15+
class BaseModel(Protocol): # type: ignore[no-redef]
16+
"""Placeholder Implementation"""
17+
18+
model_fields: ClassVar[dict[str, Any]]
19+
20+
def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
21+
"""Placeholder"""
22+
return ""
23+
24+
PYDANTIC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
25+
26+
27+
def _type_to_string(value: Any) -> str: # pragma: no cover
28+
if isinstance(value, datetime.datetime):
29+
return convert_datetime_to_gmt_iso(value)
30+
if isinstance(value, datetime.date):
31+
return convert_date_to_iso(value)
32+
if isinstance(value, enum.Enum):
33+
return str(value.value)
34+
if PYDANTIC_INSTALLED and isinstance(value, BaseModel):
35+
return value.model_dump_json()
36+
try:
37+
val = str(value)
38+
except Exception as exc:
39+
raise TypeError from exc
40+
return val
41+
42+
343
try:
444
from msgspec.json import Decoder, Encoder
545

6-
encoder, decoder = Encoder(), Decoder()
46+
encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
747
decode_json = decoder.decode
848

9-
def encode_json(data: Any) -> str:
49+
def encode_json(data: Any) -> str: # pragma: no cover
1050
return encoder.encode(data).decode("utf-8")
1151

1252
except ImportError:
1353
try:
54+
from orjson import OPT_NAIVE_UTC, OPT_SERIALIZE_NUMPY, OPT_SERIALIZE_UUID
1455
from orjson import dumps as _encode_json
1556
from orjson import loads as decode_json # type: ignore[no-redef,assignment]
1657

17-
def encode_json(data: Any) -> str:
18-
return _encode_json(data).decode("utf-8") # type: ignore[no-any-return]
58+
def encode_json(data: Any) -> str: # pragma: no cover
59+
return _encode_json(
60+
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
61+
).decode("utf-8") # type: ignore[no-any-return]
1962

2063
except ImportError:
2164
from json import dumps as encode_json # type: ignore[assignment] # noqa: F401
2265
from json import loads as decode_json # type: ignore[assignment] # noqa: F401
66+
67+
68+
def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover
69+
"""Handle datetime serialization for nested timestamps."""
70+
if not dt.tzinfo:
71+
dt = dt.replace(tzinfo=datetime.timezone.utc)
72+
return dt.isoformat().replace("+00:00", "Z")
73+
74+
75+
def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
76+
"""Handle datetime serialization for nested timestamps."""
77+
return dt.isoformat()

advanced_alchemy/extensions/litestar/plugins/init/plugin.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
OrderBy,
2222
SearchFilter,
2323
)
24-
from advanced_alchemy.service import OffsetPagination
24+
from advanced_alchemy.service import ModelDictListT, ModelDictT, ModelDTOT, ModelOrRowMappingT, ModelT, OffsetPagination
2525

2626
if TYPE_CHECKING:
2727
from click import Group
@@ -47,6 +47,11 @@
4747
"Dependency": Dependency,
4848
"DTOData": DTOData,
4949
"Sequence": Sequence,
50+
"ModelT": ModelT,
51+
"ModelDictT": ModelDictT,
52+
"ModelDTOT": ModelDTOT,
53+
"ModelDictListT": ModelDictListT,
54+
"ModelOrRowMappingT": ModelOrRowMappingT,
5055
}
5156

5257

advanced_alchemy/service/_typing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
T_co = TypeVar("T_co", covariant=True)
1919

2020
try:
21-
from pydantic import BaseModel, FailFast, TypeAdapter
21+
from pydantic import BaseModel, FailFast, TypeAdapter # pyright: ignore[reportGeneralTypeIssues]
2222

2323
PYDANTIC_INSTALLED = True
2424
except ImportError:
@@ -72,8 +72,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
7272
from msgspec import (
7373
UNSET,
7474
Struct,
75-
UnsetType, # pyright: ignore[reportAssignmentType]
76-
convert,
75+
UnsetType, # pyright: ignore[reportAssignmentType,reportGeneralTypeIssues]
76+
convert, # pyright: ignore[reportGeneralTypeIssues]
7777
)
7878

7979
MSGSPEC_INSTALLED: bool = True

advanced_alchemy/service/_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import datetime
88
from collections.abc import Sequence
9+
from enum import Enum
910
from functools import partial
1011
from pathlib import Path, PurePath
1112
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
@@ -38,6 +39,7 @@
3839
(lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
3940
(lambda x: x is datetime.date, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
4041
(lambda x: x is datetime.time, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
42+
(lambda x: x is Enum, lambda t, v: t(v.value)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
4143
]
4244

4345

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
if TYPE_CHECKING:
1616
from typing import Any
1717

18-
from sphinx.addnodes import document
18+
from sphinx.addnodes import document # type: ignore[attr-defined,unused-ignore]
1919
from sphinx.application import Sphinx
2020

2121
# -- Environmental Data ------------------------------------------------------

docs/usage/frameworks/fastapi.rst

Lines changed: 38 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,21 @@ Configure SQLAlchemy with FastAPI:
1111

1212
.. code-block:: python
1313
14-
from contextlib import asynccontextmanager
1514
from typing import AsyncGenerator
1615
1716
from fastapi import FastAPI
1817
19-
from advanced_alchemy.config import AsyncSessionConfig, SQLAlchemyAsyncConfig
20-
from advanced_alchemy.base import metadata_registry
21-
from advanced_alchemy.extensions.starlette import StarletteAdvancedAlchemy
18+
from advanced_alchemy.extensions.fastapi import AdvancedAlchemy, AsyncSessionConfig, SQLAlchemyAsyncConfig
2219
23-
session_config = AsyncSessionConfig(expire_on_commit=False)
2420
sqlalchemy_config = SQLAlchemyAsyncConfig(
2521
connection_string="sqlite+aiosqlite:///test.sqlite",
26-
session_config=session_config
22+
session_config=AsyncSessionConfig(expire_on_commit=False),
23+
create_all=True,
24+
commit_mode="autocommit",
2725
)
2826
29-
@asynccontextmanager
30-
async def on_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
31-
"""Initializes the database."""
32-
metadata = metadata_registry.get(sqlalchemy_config.bind_key)
33-
if sqlalchemy_config.create_all:
34-
async with sqlalchemy_config.get_engine().begin() as conn:
35-
await conn.run_sync(metadata.create_all)
36-
yield
37-
38-
app = FastAPI(lifespan=on_lifespan)
39-
alchemy = StarletteAdvancedAlchemy(config=sqlalchemy_config, app=app)
27+
app = FastAPI()
28+
alchemy = AdvancedAlchemy(config=sqlalchemy_config, app=app)
4029
4130
Models and Schemas
4231
------------------
@@ -88,38 +77,40 @@ Create repository and service classes:
8877

8978
.. code-block:: python
9079
80+
from typing import Annotated, AsyncGenerator, Optional
81+
82+
from advanced_alchemy.extensions.fastapi import repository, service
83+
from fastapi import Depends
9184
from sqlalchemy.ext.asyncio import AsyncSession
92-
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
93-
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
94-
from typing import AsyncGenerator
9585
96-
class AuthorRepository(SQLAlchemyAsyncRepository[AuthorModel]):
97-
"""Author repository."""
98-
model_type = AuthorModel
9986
100-
class AuthorService(SQLAlchemyAsyncRepositoryService[AuthorModel]):
87+
class AuthorService(service.SQLAlchemyAsyncRepositoryService[AuthorModel]):
10188
"""Author service."""
102-
repository_type = AuthorRepository
10389
104-
async def provide_authors_service(
105-
db_session: Annotated[AsyncSession, Depends(provide_db_session)],
106-
) -> AsyncGenerator[AuthorService, None]:
107-
"""This provides the default Authors repository."""
108-
async with AuthorService.new(session=db_session) as service:
109-
yield service
90+
class Repo(repository.SQLAlchemyAsyncRepository[AuthorModel]):
91+
"""Author repository."""
92+
model_type = AuthorModel
93+
94+
repository_type = Repo
95+
11096
11197
Dependency Injection
11298
--------------------
11399

114-
Set up dependency injection for the database session:
100+
Set up dependency injected into the request context.
115101

116102
.. code-block:: python
117103
118104
from fastapi import Request
119105
120-
async def provide_db_session(request: Request) -> AsyncSession:
121-
"""Provide a DB session."""
122-
return alchemy.get_session(request) # this is the `StarletteAdvancedAlchemy` object
106+
DatabaseSession = Annotated[AsyncSession, Depends(alchemy.provide_session())]
107+
Authors = Annotated[AuthorService, Depends(provide_authors_service)]
108+
109+
async def provide_authors_service(db_session: DatabaseSession) -> AsyncGenerator[AuthorService, None]:
110+
"""This provides the default Authors repository."""
111+
async with AuthorService.new(session=db_session) as service:
112+
yield service
113+
123114
124115
Controllers
125116
-----------
@@ -130,32 +121,31 @@ Create controllers using the service:
130121
131122
from fastapi import APIRouter, Depends
132123
from uuid import UUID
133-
from advanced_alchemy.filters import LimitOffset
134-
from advanced_alchemy.service import OffsetPagination
124+
from advanced_alchemy.extensions.fastapi import filters
135125
136126
author_router = APIRouter()
137127
138-
@author_router.get(path="/authors", response_model=OffsetPagination[Author])
128+
@author_router.get(path="/authors", response_model=filters.OffsetPagination[Author])
139129
async def list_authors(
140-
authors_service: Annotated[AuthorService, Depends(provide_authors_service)],
141-
limit_offset: Annotated[LimitOffset, Depends(provide_limit_offset_pagination)],
142-
) -> OffsetPagination[AuthorModel]:
130+
authors_service: Authors,
131+
limit_offset: Annotated[filters.LimitOffset, Depends(provide_limit_offset_pagination)],
132+
) -> filters.OffsetPagination[AuthorModel]:
143133
"""List authors."""
144134
results, total = await authors_service.list_and_count(limit_offset)
145135
return authors_service.to_schema(results, total, filters=[limit_offset])
146136
147137
@author_router.post(path="/authors", response_model=Author)
148138
async def create_author(
149-
authors_service: Annotated[AuthorService, Depends(provide_authors_service)],
139+
authors_service: Authors,
150140
data: AuthorCreate,
151141
) -> AuthorModel:
152142
"""Create a new author."""
153-
obj = await authors_service.create(data.model_dump(exclude_unset=True, exclude_none=True), auto_commit=True)
143+
obj = await authors_service.create(data)
154144
return authors_service.to_schema(obj)
155145
156146
@author_router.get(path="/authors/{author_id}", response_model=Author)
157147
async def get_author(
158-
authors_service: Annotated[AuthorService, Depends(provide_authors_service)],
148+
authors_service: Authors,
159149
author_id: UUID,
160150
) -> AuthorModel:
161151
"""Get an existing author."""
@@ -164,25 +154,21 @@ Create controllers using the service:
164154
165155
@author_router.patch(path="/authors/{author_id}", response_model=Author)
166156
async def update_author(
167-
authors_service: Annotated[AuthorService, Depends(provide_authors_service)],
157+
authors_service: Authors,
168158
data: AuthorUpdate,
169159
author_id: UUID,
170160
) -> AuthorModel:
171161
"""Update an author."""
172-
obj = await authors_service.update(
173-
data.model_dump(exclude_unset=True, exclude_none=True),
174-
item_id=author_id,
175-
auto_commit=True,
176-
)
162+
obj = await authors_service.update(data, item_id=author_id)
177163
return authors_service.to_schema(obj)
178164
179165
@author_router.delete(path="/authors/{author_id}")
180166
async def delete_author(
181-
authors_service: Annotated[AuthorService, Depends(provide_authors_service)],
167+
authors_service: Authors,
182168
author_id: UUID,
183169
) -> None:
184170
"""Delete an author from the system."""
185-
_ = await authors_service.delete(author_id, auto_commit=True)
171+
_ = await authors_service.delete(author_id)
186172
187173
Application Configuration
188174
-------------------------

examples/standalone_json.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
import asyncio
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any
66

77
from sqlalchemy.ext.asyncio import create_async_engine
88

@@ -17,7 +17,7 @@
1717
class Item(UUIDBase):
1818
name: Mapped[str]
1919
# using ``Mapped[dict]`` with an AA provided base will map it to ``JSONB``
20-
data: Mapped[dict]
20+
data: Mapped[dict[str, Any]]
2121

2222

2323
class ItemRepository(SQLAlchemyAsyncRepository[Item]):

0 commit comments

Comments
 (0)