33import sys
44import logging
55from typing import Any
6- from asyncio import gather
76from argparse import Namespace
8- from operator import methodcaller
9- from collections .abc import AsyncGenerator
10- from functools import wraps , partial , lru_cache
7+ from contextlib import suppress
8+ from functools import wraps , lru_cache
119
1210import click
1311from nonebot .rule import Rule
12+ from nonebot .adapters import Event
1413import sqlalchemy .ext .asyncio as sa_async
1514from nonebot .permission import Permission
16- from sqlalchemy . util import greenlet_spawn
15+ from sqlalchemy import URL , Table , MetaData
1716from nonebot .params import Depends , DefaultParam
1817from nonebot .plugin import Plugin , PluginMetadata
19- from nonebot .matcher import Matcher , current_matcher
20- from sqlalchemy import URL , Table , MetaData , make_url
18+ from sqlalchemy .util import ScopedRegistry , greenlet_spawn
2119from sqlalchemy .log import Identified , _qual_logger_name_for_cls
20+ from nonebot .message import run_postprocessor , event_postprocessor
21+ from nonebot .matcher import Matcher , current_event , current_matcher
2222from sqlalchemy .ext .asyncio import AsyncEngine , create_async_engine
2323from nonebot import logger , require , get_driver , get_plugin_by_module_name
2424
6565_engines : dict [str , AsyncEngine ]
6666_metadatas : dict [str , MetaData ]
6767_plugins : dict [str , Plugin ]
68- _session_factory : sa_async .async_sessionmaker [AsyncSession ]
68+ _session_factory : sa_async .async_sessionmaker [sa_async .AsyncSession ]
69+ _scoped_sessions : ScopedRegistry [sa_async .async_scoped_session [sa_async .AsyncSession ]]
6970
7071_data_dir = get_data_dir (__plugin_meta__ .name )
7172_driver = get_driver ()
@@ -93,7 +94,7 @@ async def init_orm() -> None:
9394
9495
9596def _init_orm ():
96- global _session_factory
97+ global _session_factory , _scoped_sessions
9798
9899 _init_engines ()
99100 _init_table ()
@@ -103,6 +104,12 @@ def _init_orm():
103104 ** plugin_config .sqlalchemy_session_options ,
104105 }
105106 )
107+ _scoped_sessions = ScopedRegistry (
108+ lambda : sa_async .async_scoped_session (
109+ _session_factory , lambda : current_matcher .get (None )
110+ ),
111+ lambda : id (current_event .get (None )),
112+ )
106113
107114
108115@wraps (lambda : None ) # NOTE: for dependency injection
@@ -116,25 +123,34 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:
116123AsyncSession = Annotated [sa_async .AsyncSession , Depends (get_session )]
117124
118125
119- async def get_scoped_session () -> (
120- AsyncGenerator [sa_async .async_scoped_session [AsyncSession ], None ]
121- ):
126+ async def get_scoped_session () -> sa_async .async_scoped_session [sa_async .AsyncSession ]:
122127 try :
123- scoped_session = async_scoped_session (
124- _session_factory , scopefunc = partial (current_matcher .get , None )
125- )
126- yield scoped_session
128+ return _scoped_sessions ()
127129 except NameError :
128130 raise RuntimeError ("nonebot-plugin-orm 未初始化" ) from None
129131
130- await gather (* map (methodcaller ("close" ), scoped_session .registry .registry .values ()))
131-
132132
133133async_scoped_session = Annotated [
134134 sa_async .async_scoped_session [sa_async .AsyncSession ], Depends (get_scoped_session )
135135]
136136
137137
138+ @event_postprocessor
139+ def _clear_scoped_session (event : Event ) -> None :
140+ with suppress (KeyError ):
141+ del _scoped_sessions .registry [id (event )]
142+
143+
144+ @run_postprocessor
145+ async def _close_scoped_session (event : Event , matcher : Matcher ) -> None :
146+ with suppress (KeyError ):
147+ session : sa_async .AsyncSession = _scoped_sessions .registry [
148+ id (event )
149+ ].registry .registry [matcher ]
150+ del _scoped_sessions .registry [id (event )].registry .registry [matcher ]
151+ await session .close ()
152+
153+
138154def _create_engine (engine : str | URL | dict [str , Any ] | AsyncEngine ) -> AsyncEngine :
139155 if isinstance (engine , AsyncEngine ):
140156 return engine
0 commit comments