Skip to content

Commit 689f02c

Browse files
committed
Add AsyncSQLModel and AwaitableField
1 parent 51df778 commit 689f02c

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ httpx = "0.24.1"
5454
dirty-equals = "^0.6.0"
5555
typer-cli = "^0.0.13"
5656
mkdocs-markdownextradata-plugin = ">=0.1.7,<0.3.0"
57+
pytest-asyncio = "0.21.1"
58+
aiosqlite = "0.19.0"
5759

5860
[build-system]
5961
requires = ["poetry-core"]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from typing import Any, ClassVar, Coroutine, Dict, Tuple, Type
2+
3+
from pydantic._internal._repr import Representation
4+
from sqlalchemy.util.concurrency import greenlet_spawn
5+
6+
from ... import SQLModel
7+
from ..._compat import get_annotations
8+
from ...main import SQLModelMetaclass
9+
10+
11+
class AwaitableFieldInfo(Representation):
12+
def __init__(self, *, field: str):
13+
self.field = field
14+
15+
16+
def AwaitableField(*, field: str) -> Any:
17+
return AwaitableFieldInfo(field=field)
18+
19+
20+
class AsyncSQLModelMetaclass(SQLModelMetaclass):
21+
__async_sqlmodel_awaitable_fields__: Dict[str, AwaitableFieldInfo]
22+
23+
def __new__(
24+
cls,
25+
name: str,
26+
bases: Tuple[Type[Any], ...],
27+
class_dict: Dict[str, Any],
28+
**kwargs: Any
29+
) -> Any:
30+
awaitable_fields: Dict[str, AwaitableFieldInfo] = {}
31+
dict_for_sqlmodel = {}
32+
original_annotations = get_annotations(class_dict)
33+
sqlmodel_annotations = {}
34+
awaitable_fields_annotations = {}
35+
for k, v in class_dict.items():
36+
if isinstance(v, AwaitableFieldInfo):
37+
awaitable_fields[k] = v
38+
else:
39+
dict_for_sqlmodel[k] = v
40+
for k, v in original_annotations.items():
41+
if k in awaitable_fields:
42+
awaitable_fields_annotations[k] = v
43+
else:
44+
sqlmodel_annotations[k] = v
45+
46+
dict_used = {
47+
**dict_for_sqlmodel,
48+
"__async_sqlmodel_awaitable_fields__": awaitable_fields,
49+
"__annotations__": sqlmodel_annotations,
50+
}
51+
return super().__new__(cls, name, bases, dict_used, **kwargs)
52+
53+
def __init__(
54+
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
55+
) -> None:
56+
for field_name, field_info in cls.__async_sqlmodel_awaitable_fields__.items():
57+
58+
def get_awaitable_field(
59+
self, field: str = field_info.field
60+
) -> Coroutine[Any, Any, Any]:
61+
return greenlet_spawn(getattr, self, field)
62+
63+
setattr(cls, field_name, property(get_awaitable_field)) # type: ignore
64+
65+
SQLModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
66+
67+
68+
class AsyncSQLModel(SQLModel, metaclass=AsyncSQLModelMetaclass):
69+
__async_sqlmodel_awaitable_fields__: ClassVar[Dict[str, AwaitableFieldInfo]]

0 commit comments

Comments
 (0)