33import sys
44from inspect import Parameter , Signature
55from typing import TYPE_CHECKING , Any , ClassVar
6+ from typing_extensions import Self , Unpack , Annotated
67
7- from sqlalchemy import Table , MetaData
8+ from nonebot . params import Depends
89from nonebot import get_plugin_by_module_name
10+ from sqlalchemy import Table , MetaData , select
11+ from pydantic .typing import get_args , get_origin
912from sqlalchemy .orm import Mapped , DeclarativeBase
1013
1114from .utils import DependsInner , get_annotations
1215
1316if sys .version_info >= (3 , 9 ):
14- from typing import Annotated , get_args , get_origin
17+ from typing import Annotated
1518else :
16- from typing_extensions import Annotated , get_args , get_origin
19+ from typing_extensions import Annotated
1720
1821__all__ = ("Model" ,)
1922
2730}
2831
2932
33+ class Model (DeclarativeBase ):
34+ metadata = MetaData (naming_convention = NAMING_CONVENTION )
35+
36+ if TYPE_CHECKING :
37+ __args__ : ClassVar [tuple [type [Self ], Unpack [tuple [Any , ...]]]]
38+ __origin__ : type [Annotated ]
39+
40+ __table__ : ClassVar [Table ]
41+ __bind_key__ : ClassVar [str ]
42+
43+ def __init_subclass__ (cls ) -> None :
44+ _setup_di (cls )
45+ _setup_tablename (cls )
46+
47+ super ().__init_subclass__ ()
48+
49+ if not hasattr (cls , "__table__" ):
50+ return
51+
52+ _setup_bind (cls )
53+
54+
3055def _setup_di (cls : type [Model ]) -> None :
3156 """Get signature for NoneBot's dependency injection,
3257 and set annotations for SQLAlchemy declarative class.
3358 """
59+ from . import async_scoped_session
3460
35- parameters : list [Parameter ] = []
61+ parameters : list [Parameter ] = [
62+ Parameter (
63+ "__session__" , Parameter .KEYWORD_ONLY , annotation = async_scoped_session
64+ )
65+ ]
3666
3767 annotations : dict [str , Any ] = {}
3868 for base in reversed (cls .__mro__ ):
@@ -41,13 +71,13 @@ def _setup_di(cls: type[Model]) -> None:
4171 for name , type_annotation in annotations .items ():
4272 # Check if the attribute is both a dependent and a mapped column
4373 depends_inner = None
44- if isinstance ( get_origin (type_annotation ), Annotated ) :
74+ if get_origin (type_annotation ) is Annotated :
4575 (type_annotation , * extra_args ) = get_args (type_annotation )
4676 depends_inner = next (
4777 (x for x in extra_args if isinstance (x , DependsInner )), None
4878 )
4979
50- if not isinstance ( get_origin (type_annotation ), Mapped ) :
80+ if get_origin (type_annotation ) is not Mapped :
5181 continue
5282
5383 default = getattr (cls , name , Signature .empty )
@@ -71,7 +101,14 @@ def _setup_di(cls: type[Model]) -> None:
71101 if default is not Signature .empty and not isinstance (default , Mapped ):
72102 delattr (cls , name )
73103
74- cls .__signature__ = Signature (parameters )
104+ async def dependency (
105+ * , __session__ : async_scoped_session , ** kwargs : Any
106+ ) -> Model | None :
107+ return await __session__ .scalar (select (cls ).filter_by (** kwargs ))
108+
109+ dependency .__signature__ = Signature (parameters )
110+ cls .__args__ = (Model , Depends (dependency ))
111+ cls .__origin__ = Annotated
75112
76113
77114def _setup_tablename (cls : type [Model ]) -> None :
@@ -95,23 +132,3 @@ def _setup_bind(cls: type[Model]) -> None:
95132 bind_key = ""
96133
97134 cls .__table__ .info ["bind_key" ] = bind_key
98-
99-
100- class Model (DeclarativeBase ):
101- metadata = MetaData (naming_convention = NAMING_CONVENTION )
102-
103- if TYPE_CHECKING :
104- __table__ : ClassVar [Table ]
105- __bind_key__ : ClassVar [str ]
106- __signature__ : ClassVar [Signature ]
107-
108- def __init_subclass__ (cls ) -> None :
109- _setup_di (cls )
110- _setup_tablename (cls )
111-
112- super ().__init_subclass__ ()
113-
114- if not hasattr (cls , "__table__" ):
115- return
116-
117- _setup_bind (cls )
0 commit comments