Skip to content

Commit 94d7fe4

Browse files
committed
⚡ perf(sqla)!: sync function and async dependency of session
1 parent 16acc6f commit 94d7fe4

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from . import migrate
2424
from .param import ORMParam
2525
from .config import Config, plugin_config
26-
from .utils import LoguruHandler, StreamToLogger, get_subclasses
26+
from .utils import LoguruHandler, StreamToLogger, coroutine, get_subclasses
2727

2828
if sys.version_info >= (3, 9):
2929
from typing import Annotated
@@ -118,18 +118,21 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:
118118
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
119119

120120

121-
AsyncSession = Annotated[sa_async.AsyncSession, Depends(get_session)]
121+
# NOTE: NoneBot DI will run sync function in thread pool executor,
122+
# which is poor performance for this simple function, so we wrap it as a coroutine function.
123+
AsyncSession = Annotated[sa_async.AsyncSession, Depends(coroutine(get_session))]
122124

123125

124-
async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]:
126+
def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]:
125127
try:
126128
return _scoped_sessions
127129
except NameError:
128130
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
129131

130132

131133
async_scoped_session = Annotated[
132-
sa_async.async_scoped_session[sa_async.AsyncSession], Depends(get_scoped_session)
134+
sa_async.async_scoped_session[sa_async.AsyncSession],
135+
Depends(coroutine(get_scoped_session)),
133136
]
134137

135138

nonebot_plugin_orm/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,14 @@ def get_subclasses(cls: type[_T]) -> Generator[type[_T], None, None]:
251251
yield from get_subclasses(subclass)
252252

253253

254+
def coroutine(func: Callable[_P, _T]) -> Callable[_P, Coroutine[Any, Any, _T]]:
255+
@wraps(func)
256+
async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
257+
return func(*args, **kwargs)
258+
259+
return wrapper
260+
261+
254262
if sys.version_info >= (3, 10):
255263
from inspect import get_annotations as get_annotations # nopycln: import
256264
else:

0 commit comments

Comments
 (0)