44import logging
55from asyncio import gather
66from operator import methodcaller
7- from functools import wraps , partial
87from typing import Any , AsyncGenerator
8+ from functools import wraps , partial , lru_cache
99
1010from nonebot .params import Depends
11- from nonebot .plugin import PluginMetadata
12- import sqlalchemy .ext .asyncio as sa_asyncio
11+ import sqlalchemy .ext .asyncio as sa_async
1312from sqlalchemy .util import greenlet_spawn
1413from nonebot .matcher import current_matcher
15- from nonebot import logger , require , get_driver
14+ from nonebot . plugin import Plugin , PluginMetadata
1615from sqlalchemy import URL , Table , MetaData , make_url
1716from sqlalchemy .ext .asyncio import AsyncEngine , create_async_engine
17+ from nonebot import logger , require , get_driver , get_plugin_by_module_name
1818
1919from . import migrate
2020from .config import Config
5555 "merge" ,
5656 "upgrade" ,
5757 "downgrade" ,
58+ "sync" ,
5859 "show" ,
5960 "history" ,
6061 "heads" ,
7778_binds : dict [type [Model ], AsyncEngine ]
7879_engines : dict [str , AsyncEngine ]
7980_metadatas : dict [str , MetaData ]
80- _session_factory : sa_asyncio .async_sessionmaker [AsyncSession ]
81+ _plugins : dict [str , Plugin ]
82+ _session_factory : sa_async .async_sessionmaker [AsyncSession ]
8183
8284_driver = get_driver ()
8385
@@ -90,34 +92,33 @@ async def init_orm() -> None:
9092 if plugin_config .alembic_startup_check :
9193 await greenlet_spawn (migrate .check , alembic_config )
9294 else :
93- logger .warning ("跳过启动检查,直接创建所有表并标记数据库为最新修订版本" )
94- await migrate ._upgrade_fast (alembic_config )
95- await greenlet_spawn (migrate .stamp , alembic_config )
95+ logger .warning ("跳过启动检查,正在同步数据库模式..." )
96+ await greenlet_spawn (migrate .sync , alembic_config )
9697
9798
9899def _init_orm ():
99100 global _session_factory
100101
101102 _init_engines ()
102103 _init_table ()
103- _session_factory = sa_asyncio .async_sessionmaker (
104+ _session_factory = sa_async .async_sessionmaker (
104105 _engines ["" ], binds = _binds , ** plugin_config .sqlalchemy_session_options
105106 )
106107
107108
108109@wraps (lambda : None ) # NOTE: for dependency injection
109- def get_session (** local_kw : Any ) -> sa_asyncio .AsyncSession :
110+ def get_session (** local_kw : Any ) -> sa_async .AsyncSession :
110111 try :
111112 return _session_factory (** local_kw )
112113 except NameError :
113114 raise RuntimeError ("nonebot-plugin-orm 未初始化" ) from None
114115
115116
116- AsyncSession = Annotated [sa_asyncio .AsyncSession , Depends (get_session )]
117+ AsyncSession = Annotated [sa_async .AsyncSession , Depends (get_session )]
117118
118119
119120async def get_scoped_session () -> (
120- AsyncGenerator [sa_asyncio .async_scoped_session [AsyncSession ], None ]
121+ AsyncGenerator [sa_async .async_scoped_session [AsyncSession ], None ]
121122):
122123 try :
123124 scoped_session = async_scoped_session (
@@ -131,7 +132,7 @@ async def get_scoped_session() -> (
131132
132133
133134async_scoped_session = Annotated [
134- sa_asyncio .async_scoped_session [AsyncSession ], Depends (get_scoped_session )
135+ sa_async .async_scoped_session [sa_async . AsyncSession ], Depends (get_scoped_session )
135136]
136137
137138
@@ -186,20 +187,21 @@ def _init_engines():
186187
187188
188189def _init_table ():
189- global _binds , _metadatas
190+ global _binds , _metadatas , _plugins
190191
191192 _binds = {}
193+ _plugins = {}
192194
193- if len (_engines ) == 1 : # NOTE: common case: only default engine
194- _metadatas = {"" : Model .metadata }
195- return
196-
195+ _get_plugin_by_module_name = lru_cache (None )(get_plugin_by_module_name )
197196 for model in Model .__subclasses__ ():
198197 table : Table | None = getattr (model , "__table__" , None )
199198
200199 if table is None or (bind_key := table .info .get ("bind_key" )) is None :
201200 continue
202201
202+ if plugin := _get_plugin_by_module_name (model .__module__ ):
203+ _plugins [plugin .name ] = plugin
204+
203205 _binds [model ] = _engines .get (bind_key , _engines ["" ])
204206 table .to_metadata (_metadatas .get (bind_key , _metadatas ["" ]))
205207
0 commit comments