Skip to content

Commit 10c8df6

Browse files
committed
✨ feat(sqla): Model support dependency injection
1 parent d46f5ce commit 10c8df6

File tree

2 files changed

+74
-40
lines changed

2 files changed

+74
-40
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
11
from __future__ import annotations
22

3+
import sys
34
import logging
4-
from typing import Any
5+
from asyncio import gather
6+
from operator import methodcaller
57
from functools import wraps, partial
8+
from typing import Any, AsyncGenerator
69

10+
from nonebot.params import Depends
711
from nonebot.log import LoguruHandler
812
from nonebot.plugin import PluginMetadata
13+
import sqlalchemy.ext.asyncio as sa_asyncio
914
from sqlalchemy.util import greenlet_spawn
1015
from nonebot.matcher import current_matcher
1116
from nonebot import logger, require, get_driver
1217
from sqlalchemy import URL, Table, MetaData, make_url
13-
from sqlalchemy.ext.asyncio import (
14-
AsyncEngine,
15-
AsyncSession,
16-
async_sessionmaker,
17-
create_async_engine,
18-
async_scoped_session,
19-
)
18+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
2019

2120
from . import migrate
2221
from .model import Model
2322
from .utils import StreamToLogger
2423
from .config import Config, plugin_config
2524

25+
if sys.version_info >= (3, 10):
26+
from typing import Annotated
27+
else:
28+
from typing_extensions import Annotated
29+
2630
__all__ = (
2731
# __init__
2832
"init_orm",
@@ -75,7 +79,7 @@
7579
_binds: dict[type[Model], AsyncEngine]
7680
_engines: dict[str, AsyncEngine]
7781
_metadatas: dict[str, MetaData]
78-
_session_factory: async_sessionmaker[AsyncSession]
82+
_session_factory: sa_asyncio.async_sessionmaker[AsyncSession]
7983

8084
_driver = get_driver()
8185

@@ -86,7 +90,7 @@ async def init_orm():
8690

8791
_init_engines()
8892
_init_table()
89-
_session_factory = async_sessionmaker(
93+
_session_factory = sa_asyncio.async_sessionmaker(
9094
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
9195
)
9296

@@ -100,21 +104,34 @@ async def init_orm():
100104

101105

102106
@wraps(lambda: None) # NOTE: for dependency injection
103-
def get_session(**local_kw: Any) -> AsyncSession:
107+
def get_session(**local_kw: Any) -> sa_asyncio.AsyncSession:
104108
try:
105109
return _session_factory(**local_kw)
106110
except NameError:
107111
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
108112

109113

110-
def get_scoped_session() -> async_scoped_session[AsyncSession]:
114+
AsyncSession = Annotated[sa_asyncio.AsyncSession, Depends(get_session)]
115+
116+
117+
async def get_scoped_session() -> (
118+
AsyncGenerator[sa_asyncio.async_scoped_session[AsyncSession], None]
119+
):
111120
try:
112-
return async_scoped_session(
121+
scoped_session = async_scoped_session(
113122
_session_factory, scopefunc=partial(current_matcher.get, None)
114123
)
124+
yield scoped_session
115125
except NameError:
116126
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
117127

128+
await gather(*map(methodcaller("close"), scoped_session.registry.registry.values()))
129+
130+
131+
async_scoped_session = Annotated[
132+
sa_asyncio.async_scoped_session[AsyncSession], Depends(get_scoped_session)
133+
]
134+
118135

119136
def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine:
120137
if isinstance(engine, AsyncEngine):

nonebot_plugin_orm/model.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import sys
44
from inspect import Parameter, Signature
55
from typing import TYPE_CHECKING, Any, ClassVar
6+
from typing_extensions import Self, Unpack, Annotated
67

7-
from sqlalchemy import Table, MetaData
8+
from nonebot.params import Depends
89
from nonebot import get_plugin_by_module_name
10+
from sqlalchemy import Table, MetaData, select
11+
from pydantic.typing import get_args, get_origin
912
from sqlalchemy.orm import Mapped, DeclarativeBase
1013

1114
from .utils import DependsInner, get_annotations
1215

1316
if sys.version_info >= (3, 9):
14-
from typing import Annotated, get_args, get_origin
17+
from typing import Annotated
1518
else:
16-
from typing_extensions import Annotated, get_args, get_origin
19+
from typing_extensions import Annotated
1720

1821
__all__ = ("Model",)
1922

@@ -27,12 +30,39 @@
2730
}
2831

2932

33+
class Model(DeclarativeBase):
34+
metadata = MetaData(naming_convention=NAMING_CONVENTION)
35+
36+
if TYPE_CHECKING:
37+
__args__: ClassVar[tuple[type[Self], Unpack[tuple[Any, ...]]]]
38+
__origin__: type[Annotated]
39+
40+
__table__: ClassVar[Table]
41+
__bind_key__: ClassVar[str]
42+
43+
def __init_subclass__(cls) -> None:
44+
_setup_di(cls)
45+
_setup_tablename(cls)
46+
47+
super().__init_subclass__()
48+
49+
if not hasattr(cls, "__table__"):
50+
return
51+
52+
_setup_bind(cls)
53+
54+
3055
def _setup_di(cls: type[Model]) -> None:
3156
"""Get signature for NoneBot's dependency injection,
3257
and set annotations for SQLAlchemy declarative class.
3358
"""
59+
from . import async_scoped_session
3460

35-
parameters: list[Parameter] = []
61+
parameters: list[Parameter] = [
62+
Parameter(
63+
"__session__", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
64+
)
65+
]
3666

3767
annotations: dict[str, Any] = {}
3868
for base in reversed(cls.__mro__):
@@ -41,13 +71,13 @@ def _setup_di(cls: type[Model]) -> None:
4171
for name, type_annotation in annotations.items():
4272
# Check if the attribute is both a dependent and a mapped column
4373
depends_inner = None
44-
if isinstance(get_origin(type_annotation), Annotated):
74+
if get_origin(type_annotation) is Annotated:
4575
(type_annotation, *extra_args) = get_args(type_annotation)
4676
depends_inner = next(
4777
(x for x in extra_args if isinstance(x, DependsInner)), None
4878
)
4979

50-
if not isinstance(get_origin(type_annotation), Mapped):
80+
if get_origin(type_annotation) is not Mapped:
5181
continue
5282

5383
default = getattr(cls, name, Signature.empty)
@@ -71,7 +101,14 @@ def _setup_di(cls: type[Model]) -> None:
71101
if default is not Signature.empty and not isinstance(default, Mapped):
72102
delattr(cls, name)
73103

74-
cls.__signature__ = Signature(parameters)
104+
async def dependency(
105+
*, __session__: async_scoped_session, **kwargs: Any
106+
) -> Model | None:
107+
return await __session__.scalar(select(cls).filter_by(**kwargs))
108+
109+
dependency.__signature__ = Signature(parameters)
110+
cls.__args__ = (Model, Depends(dependency))
111+
cls.__origin__ = Annotated
75112

76113

77114
def _setup_tablename(cls: type[Model]) -> None:
@@ -95,23 +132,3 @@ def _setup_bind(cls: type[Model]) -> None:
95132
bind_key = ""
96133

97134
cls.__table__.info["bind_key"] = bind_key
98-
99-
100-
class Model(DeclarativeBase):
101-
metadata = MetaData(naming_convention=NAMING_CONVENTION)
102-
103-
if TYPE_CHECKING:
104-
__table__: ClassVar[Table]
105-
__bind_key__: ClassVar[str]
106-
__signature__: ClassVar[Signature]
107-
108-
def __init_subclass__(cls) -> None:
109-
_setup_di(cls)
110-
_setup_tablename(cls)
111-
112-
super().__init_subclass__()
113-
114-
if not hasattr(cls, "__table__"):
115-
return
116-
117-
_setup_bind(cls)

0 commit comments

Comments
 (0)