Skip to content

Commit 259bbfd

Browse files
committed
♻️ refactor: split init_orm()
1 parent de4b6e2 commit 259bbfd

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,8 @@
8484

8585

8686
@_driver.on_startup
87-
async def init_orm():
88-
global _session_factory
89-
90-
_init_engines()
91-
_init_table()
92-
_session_factory = sa_asyncio.async_sessionmaker(
93-
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
94-
)
87+
async def init_orm() -> None:
88+
_init_orm()
9589

9690
with migrate.AlembicConfig(stdout=StreamToLogger()) as alembic_config:
9791
if plugin_config.alembic_startup_check:
@@ -102,6 +96,16 @@ async def init_orm():
10296
await greenlet_spawn(migrate.stamp, alembic_config)
10397

10498

99+
def _init_orm():
100+
global _session_factory
101+
102+
_init_engines()
103+
_init_table()
104+
_session_factory = sa_asyncio.async_sessionmaker(
105+
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
106+
)
107+
108+
105109
@wraps(lambda: None) # NOTE: for dependency injection
106110
def get_session(**local_kw: Any) -> sa_asyncio.AsyncSession:
107111
try:

nonebot_plugin_orm/__main__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import click
99
from alembic.script import Script
10+
from sqlalchemy.util import greenlet_spawn
1011

1112
from . import migrate
1213
from .config import plugin_config
@@ -205,7 +206,7 @@ def stamp(*args, **kwargs) -> None:
205206
@orm.command()
206207
@click.argument("rev", default="current")
207208
@click.pass_obj
208-
def edit(*args, **kwargs):
209+
def edit(*args, **kwargs) -> None:
209210
"""使用 $EDITOR 编辑修订文件。"""
210211
return migrate.edit(*args, **kwargs)
211212

@@ -218,12 +219,14 @@ def ensure_version(*args, **kwargs) -> None:
218219
return migrate.ensure_version(*args, **kwargs)
219220

220221

221-
def main():
222-
from . import _init_table, _init_engines
222+
def main(*args, **kwargs) -> None:
223+
from . import _init_orm
223224

224-
_init_engines()
225-
_init_table()
226-
orm(prog_name="nb orm")
225+
if not (args or kwargs):
226+
kwargs["prog_name"] = "nb orm"
227+
228+
_init_orm()
229+
orm(*args, **kwargs)
227230

228231

229232
if __name__ == "__main__":

0 commit comments

Comments
 (0)