Skip to content

Commit 4f1a235

Browse files
committed
🐛 fix(sqla): Mapped[Model]
1 parent ffcb953 commit 4f1a235

File tree

2 files changed

+44
-22
lines changed

2 files changed

+44
-22
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
1818

1919
from . import migrate
20-
from .model import Model
21-
from .config import Config, plugin_config
20+
from .config import Config
2221
from .utils import LoguruHandler, StreamToLogger
2322

2423
if sys.version_info >= (3, 10):
@@ -221,9 +220,9 @@ def _init_logger():
221220
l.setLevel(level)
222221

223222

224-
_init_logger()
225-
226223
from .sql import *
227224
from .model import *
228225
from .config import *
229226
from .migrate import *
227+
228+
_init_logger()

nonebot_plugin_orm/model.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import sys
44
from inspect import Parameter, Signature
55
from typing import TYPE_CHECKING, Any, ClassVar
6-
from typing_extensions import Self, Unpack, Annotated
76

87
from nonebot.params import Depends
98
from nonebot import get_plugin_by_module_name
109
from sqlalchemy import Table, MetaData, select
1110
from pydantic.typing import get_args, get_origin
1211
from sqlalchemy.orm import Mapped, DeclarativeBase
12+
from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
1313

1414
from .utils import DependsInner, get_annotations
1515

@@ -30,13 +30,47 @@
3030
}
3131

3232

33-
class Model(DeclarativeBase):
33+
class ModelMeta(DeclarativeAttributeIntercept):
34+
if TYPE_CHECKING:
35+
__signature__: Signature
36+
37+
def __new__(
38+
mcs,
39+
name: str,
40+
bases: tuple[type, ...],
41+
namespace: dict[str, Any],
42+
**kwargs: Any,
43+
) -> ModelMeta:
44+
from . import async_scoped_session
45+
46+
cls: ModelMeta = super().__new__(mcs, name, bases, namespace, **kwargs)
47+
48+
if not (signature := getattr(cls, "__signature__", None)):
49+
return cls
50+
51+
async def dependency(
52+
*, __session: async_scoped_session, **kwargs: Any
53+
) -> ModelMeta | None:
54+
return await __session.scalar(select(cls).filter_by(**kwargs))
55+
56+
dependency.__signature__ = Signature(
57+
(
58+
Parameter(
59+
"_ModelMeta__session",
60+
Parameter.KEYWORD_ONLY,
61+
annotation=async_scoped_session,
62+
),
63+
*signature.parameters.values(),
64+
)
65+
)
66+
67+
return Annotated[cls, Depends(dependency)]
68+
69+
70+
class Model(DeclarativeBase, metaclass=ModelMeta):
3471
metadata = MetaData(naming_convention=NAMING_CONVENTION)
3572

3673
if TYPE_CHECKING:
37-
__args__: ClassVar[tuple[type[Self], Unpack[tuple[Any, ...]]]]
38-
__origin__: type[Annotated]
39-
4074
__table__: ClassVar[Table]
4175
__bind_key__: ClassVar[str]
4276

@@ -58,11 +92,7 @@ def _setup_di(cls: type[Model]) -> None:
5892
"""
5993
from . import async_scoped_session
6094

61-
parameters: list[Parameter] = [
62-
Parameter(
63-
"__session__", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
64-
)
65-
]
95+
parameters: list[Parameter] = []
6696

6797
annotations: dict[str, Any] = {}
6898
for base in reversed(cls.__mro__):
@@ -101,14 +131,7 @@ def _setup_di(cls: type[Model]) -> None:
101131
if default is not Signature.empty and not isinstance(default, Mapped):
102132
delattr(cls, name)
103133

104-
async def dependency(
105-
*, __session__: async_scoped_session, **kwargs: Any
106-
) -> Model | None:
107-
return await __session__.scalar(select(cls).filter_by(**kwargs))
108-
109-
dependency.__signature__ = Signature(parameters)
110-
cls.__args__ = (Model, Depends(dependency))
111-
cls.__origin__ = Annotated
134+
cls.__signature__ = Signature(parameters)
112135

113136

114137
def _setup_tablename(cls: type[Model]) -> None:

0 commit comments

Comments
 (0)