Skip to content

Commit fa9640a

Browse files
committed
🚸 feat(alembic): upgrade prompt on startup
1 parent 5ebcc57 commit fa9640a

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ async def init_orm() -> None:
8383
try:
8484
await greenlet_spawn(migrate.check, alembic_config)
8585
except click.UsageError:
86-
logger.error("启动检查失败")
87-
raise
86+
if not click.confirm("目标数据库未更新到最新迁移, 是否更新?"):
87+
raise
88+
cmd_opts.cmd = (migrate.upgrade, [], [])
89+
await greenlet_spawn(migrate.upgrade, alembic_config)
8890
else:
8991
logger.warning("跳过启动检查, 正在同步数据库模式...")
9092
cmd_opts.cmd = (migrate.sync, ["revision"], [])
@@ -110,7 +112,6 @@ def _init_orm():
110112
run_postprocessor(_scoped_sessions.remove)
111113

112114

113-
@wraps(lambda: None) # NOTE: for dependency injection
114115
def get_session(**local_kw: Any) -> sa_async.AsyncSession:
115116
try:
116117
return _session_factory(**local_kw)
@@ -121,7 +122,8 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:
121122
# NOTE: NoneBot DI will run sync function in thread pool executor,
122123
# which is poor performance for this simple function, so we wrap it as a coroutine function.
123124
AsyncSession = Annotated[
124-
sa_async.AsyncSession, Depends(coroutine(get_session), use_cache=False)
125+
sa_async.AsyncSession,
126+
Depends(coroutine(wraps(lambda: None)(get_session)), use_cache=False),
125127
]
126128

127129

nonebot_plugin_orm/migrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sqlalchemy import MetaData, Connection
2424
from alembic.util.editor import open_in_editor
2525
from alembic.script import Script, ScriptDirectory
26+
from alembic.util import AutogenerateDiffsDetected
2627
from alembic.util.langhelpers import rev_id as _rev_id
2728
from alembic.operations.ops import UpgradeOps, DowngradeOps
2829
from alembic.migration import StampStep, RevisionStep, MigrationContext
@@ -542,7 +543,7 @@ def retrieve_migrations(
542543
migration_script = revision_context.generated_revisions[-1]
543544
diffs = cast(UpgradeOps, migration_script.upgrade_ops).as_diffs()
544545
if diffs:
545-
raise click.UsageError(f"检测到新的升级操作:\n{pformat(diffs)}")
546+
raise AutogenerateDiffsDetected(f"检测到新的升级操作:\n{pformat(diffs)}")
546547
else:
547548
config.print_stdout("没有检测到新的升级操作")
548549

0 commit comments

Comments
 (0)