Skip to content

Commit b732577

Browse files
committed
🐛 fix(sqla): close scoped session by processor
1 parent 0a45835 commit b732577

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33
import sys
44
import logging
55
from typing import Any
6-
from asyncio import gather
76
from argparse import Namespace
8-
from operator import methodcaller
9-
from collections.abc import AsyncGenerator
10-
from functools import wraps, partial, lru_cache
7+
from contextlib import suppress
8+
from functools import wraps, lru_cache
119

1210
import click
1311
from nonebot.rule import Rule
12+
from nonebot.adapters import Event
1413
import sqlalchemy.ext.asyncio as sa_async
1514
from nonebot.permission import Permission
16-
from sqlalchemy.util import greenlet_spawn
15+
from sqlalchemy import URL, Table, MetaData
1716
from nonebot.params import Depends, DefaultParam
1817
from nonebot.plugin import Plugin, PluginMetadata
19-
from nonebot.matcher import Matcher, current_matcher
20-
from sqlalchemy import URL, Table, MetaData, make_url
18+
from sqlalchemy.util import ScopedRegistry, greenlet_spawn
2119
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
20+
from nonebot.message import run_postprocessor, event_postprocessor
21+
from nonebot.matcher import Matcher, current_event, current_matcher
2222
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
2323
from nonebot import logger, require, get_driver, get_plugin_by_module_name
2424

@@ -65,7 +65,8 @@
6565
_engines: dict[str, AsyncEngine]
6666
_metadatas: dict[str, MetaData]
6767
_plugins: dict[str, Plugin]
68-
_session_factory: sa_async.async_sessionmaker[AsyncSession]
68+
_session_factory: sa_async.async_sessionmaker[sa_async.AsyncSession]
69+
_scoped_sessions: ScopedRegistry[sa_async.async_scoped_session[sa_async.AsyncSession]]
6970

7071
_data_dir = get_data_dir(__plugin_meta__.name)
7172
_driver = get_driver()
@@ -93,7 +94,7 @@ async def init_orm() -> None:
9394

9495

9596
def _init_orm():
96-
global _session_factory
97+
global _session_factory, _scoped_sessions
9798

9899
_init_engines()
99100
_init_table()
@@ -103,6 +104,12 @@ def _init_orm():
103104
**plugin_config.sqlalchemy_session_options,
104105
}
105106
)
107+
_scoped_sessions = ScopedRegistry(
108+
lambda: sa_async.async_scoped_session(
109+
_session_factory, lambda: current_matcher.get(None)
110+
),
111+
lambda: id(current_event.get(None)),
112+
)
106113

107114

108115
@wraps(lambda: None) # NOTE: for dependency injection
@@ -116,25 +123,34 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:
116123
AsyncSession = Annotated[sa_async.AsyncSession, Depends(get_session)]
117124

118125

119-
async def get_scoped_session() -> (
120-
AsyncGenerator[sa_async.async_scoped_session[AsyncSession], None]
121-
):
126+
async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]:
122127
try:
123-
scoped_session = async_scoped_session(
124-
_session_factory, scopefunc=partial(current_matcher.get, None)
125-
)
126-
yield scoped_session
128+
return _scoped_sessions()
127129
except NameError:
128130
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
129131

130-
await gather(*map(methodcaller("close"), scoped_session.registry.registry.values()))
131-
132132

133133
async_scoped_session = Annotated[
134134
sa_async.async_scoped_session[sa_async.AsyncSession], Depends(get_scoped_session)
135135
]
136136

137137

138+
@event_postprocessor
139+
def _clear_scoped_session(event: Event) -> None:
140+
with suppress(KeyError):
141+
del _scoped_sessions.registry[id(event)]
142+
143+
144+
@run_postprocessor
145+
async def _close_scoped_session(event: Event, matcher: Matcher) -> None:
146+
with suppress(KeyError):
147+
session: sa_async.AsyncSession = _scoped_sessions.registry[
148+
id(event)
149+
].registry.registry[matcher]
150+
del _scoped_sessions.registry[id(event)].registry.registry[matcher]
151+
await session.close()
152+
153+
138154
def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine:
139155
if isinstance(engine, AsyncEngine):
140156
return engine

0 commit comments

Comments
 (0)