Skip to content

Commit 477c0f4

Browse files
committed
⚡ perf(sqla): single layer scoped session
1 parent 28326e3 commit 477c0f4

File tree

2 files changed

+11
-31
lines changed

2 files changed

+11
-31
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,18 @@
44
import logging
55
from typing import Any
66
from argparse import Namespace
7-
from contextlib import suppress
87
from functools import wraps, lru_cache
98

109
import click
1110
from nonebot.rule import Rule
12-
from nonebot.adapters import Event
1311
import sqlalchemy.ext.asyncio as sa_async
1412
from nonebot.permission import Permission
13+
from sqlalchemy.util import greenlet_spawn
1514
from sqlalchemy import URL, Table, MetaData
15+
from nonebot.message import run_postprocessor
1616
from nonebot.params import Depends, DefaultParam
1717
from nonebot.plugin import Plugin, PluginMetadata
18-
from sqlalchemy.util import ScopedRegistry, greenlet_spawn
1918
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
20-
from nonebot.message import run_postprocessor, event_postprocessor
2119
from nonebot.matcher import Matcher, current_event, current_matcher
2220
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
2321
from nonebot import logger, require, get_driver, get_plugin_by_module_name
@@ -66,7 +64,7 @@
6664
_metadatas: dict[str, MetaData]
6765
_plugins: dict[str, Plugin]
6866
_session_factory: sa_async.async_sessionmaker[sa_async.AsyncSession]
69-
_scoped_sessions: ScopedRegistry[sa_async.async_scoped_session[sa_async.AsyncSession]]
67+
_scoped_sessions: sa_async.async_scoped_session[sa_async.AsyncSession]
7068

7169
_data_dir = get_data_dir(__plugin_meta__.name)
7270
_driver = get_driver()
@@ -104,16 +102,12 @@ def _init_orm():
104102
**plugin_config.sqlalchemy_session_options,
105103
}
106104
)
107-
_scoped_sessions = ScopedRegistry(
108-
lambda: sa_async.async_scoped_session(
109-
_session_factory, lambda: current_matcher.get(None)
110-
),
111-
lambda: id(current_event.get(None)),
105+
_scoped_sessions = sa_async.async_scoped_session(
106+
_session_factory,
107+
lambda: (id(current_event.get(None)), current_matcher.get(None)),
112108
)
113109

114-
# XXX: workaround for https://github.com/nonebot/nonebot2/issues/2475
115-
event_postprocessor(_clear_scoped_session)
116-
run_postprocessor(_close_scoped_session)
110+
run_postprocessor(_scoped_sessions.remove)
117111

118112

119113
@wraps(lambda: None) # NOTE: for dependency injection
@@ -129,7 +123,7 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:
129123

130124
async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]:
131125
try:
132-
return _scoped_sessions()
126+
return _scoped_sessions
133127
except NameError:
134128
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
135129

@@ -139,22 +133,6 @@ async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSe
139133
]
140134

141135

142-
# @event_postprocessor
143-
def _clear_scoped_session(event: Event) -> None:
144-
with suppress(KeyError):
145-
del _scoped_sessions.registry[id(event)]
146-
147-
148-
# @run_postprocessor
149-
async def _close_scoped_session(event: Event, matcher: Matcher) -> None:
150-
with suppress(KeyError):
151-
session: sa_async.AsyncSession = _scoped_sessions.registry[
152-
id(event)
153-
].registry.registry[matcher]
154-
del _scoped_sessions.registry[id(event)].registry.registry[matcher]
155-
await session.close()
156-
157-
158136
def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine:
159137
if isinstance(engine, AsyncEngine):
160138
return engine

nonebot_plugin_orm/param.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def _check_param(
165165

166166
if depends_inner is not None:
167167
dependency = compile_dependency(depends_inner.dependency, option)
168-
elif all(map(isclass, models)) and all(map(issubclass, models, repeat(Model))):
168+
elif all(map(isclass, models)) and all(
169+
map(issubclass, cast(Tuple[type, ...], models), repeat(Model))
170+
):
169171
models = cast(Tuple[Type[Model], ...], models)
170172
dependency = compile_dependency(
171173
select(*models).where(

0 commit comments

Comments
 (0)