33import sys
44from inspect import Parameter , Signature
55from typing import TYPE_CHECKING , Any , ClassVar
6- from typing_extensions import Self , Unpack , Annotated
76
87from nonebot .params import Depends
98from nonebot import get_plugin_by_module_name
109from sqlalchemy import Table , MetaData , select
1110from pydantic .typing import get_args , get_origin
1211from sqlalchemy .orm import Mapped , DeclarativeBase
12+ from sqlalchemy .orm .decl_api import DeclarativeAttributeIntercept
1313
1414from .utils import DependsInner , get_annotations
1515
3030}
3131
3232
33- class Model (DeclarativeBase ):
33+ class ModelMeta (DeclarativeAttributeIntercept ):
34+ if TYPE_CHECKING :
35+ __signature__ : Signature
36+
37+ def __new__ (
38+ mcs ,
39+ name : str ,
40+ bases : tuple [type , ...],
41+ namespace : dict [str , Any ],
42+ ** kwargs : Any ,
43+ ) -> ModelMeta :
44+ from . import async_scoped_session
45+
46+ cls : ModelMeta = super ().__new__ (mcs , name , bases , namespace , ** kwargs )
47+
48+ if not (signature := getattr (cls , "__signature__" , None )):
49+ return cls
50+
51+ async def dependency (
52+ * , __session : async_scoped_session , ** kwargs : Any
53+ ) -> ModelMeta | None :
54+ return await __session .scalar (select (cls ).filter_by (** kwargs ))
55+
56+ dependency .__signature__ = Signature (
57+ (
58+ Parameter (
59+ "_ModelMeta__session" ,
60+ Parameter .KEYWORD_ONLY ,
61+ annotation = async_scoped_session ,
62+ ),
63+ * signature .parameters .values (),
64+ )
65+ )
66+
67+ return Annotated [cls , Depends (dependency )]
68+
69+
70+ class Model (DeclarativeBase , metaclass = ModelMeta ):
3471 metadata = MetaData (naming_convention = NAMING_CONVENTION )
3572
3673 if TYPE_CHECKING :
37- __args__ : ClassVar [tuple [type [Self ], Unpack [tuple [Any , ...]]]]
38- __origin__ : type [Annotated ]
39-
4074 __table__ : ClassVar [Table ]
4175 __bind_key__ : ClassVar [str ]
4276
@@ -58,11 +92,7 @@ def _setup_di(cls: type[Model]) -> None:
5892 """
5993 from . import async_scoped_session
6094
61- parameters : list [Parameter ] = [
62- Parameter (
63- "__session__" , Parameter .KEYWORD_ONLY , annotation = async_scoped_session
64- )
65- ]
95+ parameters : list [Parameter ] = []
6696
6797 annotations : dict [str , Any ] = {}
6898 for base in reversed (cls .__mro__ ):
@@ -101,14 +131,7 @@ def _setup_di(cls: type[Model]) -> None:
101131 if default is not Signature .empty and not isinstance (default , Mapped ):
102132 delattr (cls , name )
103133
104- async def dependency (
105- * , __session__ : async_scoped_session , ** kwargs : Any
106- ) -> Model | None :
107- return await __session__ .scalar (select (cls ).filter_by (** kwargs ))
108-
109- dependency .__signature__ = Signature (parameters )
110- cls .__args__ = (Model , Depends (dependency ))
111- cls .__origin__ = Annotated
134+ cls .__signature__ = Signature (parameters )
112135
113136
114137def _setup_tablename (cls : type [Model ]) -> None :
0 commit comments