Skip to content

Commit 95624a6

Browse files
committed
✨ feat(alembic): sync
1 parent 45405ae commit 95624a6

File tree

4 files changed

+151
-103
lines changed

4 files changed

+151
-103
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
import logging
55
from asyncio import gather
66
from operator import methodcaller
7-
from functools import wraps, partial
87
from typing import Any, AsyncGenerator
8+
from functools import wraps, partial, lru_cache
99

1010
from nonebot.params import Depends
11-
from nonebot.plugin import PluginMetadata
12-
import sqlalchemy.ext.asyncio as sa_asyncio
11+
import sqlalchemy.ext.asyncio as sa_async
1312
from sqlalchemy.util import greenlet_spawn
1413
from nonebot.matcher import current_matcher
15-
from nonebot import logger, require, get_driver
14+
from nonebot.plugin import Plugin, PluginMetadata
1615
from sqlalchemy import URL, Table, MetaData, make_url
1716
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
17+
from nonebot import logger, require, get_driver, get_plugin_by_module_name
1818

1919
from . import migrate
2020
from .config import Config
@@ -55,6 +55,7 @@
5555
"merge",
5656
"upgrade",
5757
"downgrade",
58+
"sync",
5859
"show",
5960
"history",
6061
"heads",
@@ -77,7 +78,8 @@
7778
_binds: dict[type[Model], AsyncEngine]
7879
_engines: dict[str, AsyncEngine]
7980
_metadatas: dict[str, MetaData]
80-
_session_factory: sa_asyncio.async_sessionmaker[AsyncSession]
81+
_plugins: dict[str, Plugin]
82+
_session_factory: sa_async.async_sessionmaker[AsyncSession]
8183

8284
_driver = get_driver()
8385

@@ -90,34 +92,33 @@ async def init_orm() -> None:
9092
if plugin_config.alembic_startup_check:
9193
await greenlet_spawn(migrate.check, alembic_config)
9294
else:
93-
logger.warning("跳过启动检查,直接创建所有表并标记数据库为最新修订版本")
94-
await migrate._upgrade_fast(alembic_config)
95-
await greenlet_spawn(migrate.stamp, alembic_config)
95+
logger.warning("跳过启动检查,正在同步数据库模式...")
96+
await greenlet_spawn(migrate.sync, alembic_config)
9697

9798

9899
def _init_orm():
99100
global _session_factory
100101

101102
_init_engines()
102103
_init_table()
103-
_session_factory = sa_asyncio.async_sessionmaker(
104+
_session_factory = sa_async.async_sessionmaker(
104105
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
105106
)
106107

107108

108109
@wraps(lambda: None) # NOTE: for dependency injection
109-
def get_session(**local_kw: Any) -> sa_asyncio.AsyncSession:
110+
def get_session(**local_kw: Any) -> sa_async.AsyncSession:
110111
try:
111112
return _session_factory(**local_kw)
112113
except NameError:
113114
raise RuntimeError("nonebot-plugin-orm 未初始化") from None
114115

115116

116-
AsyncSession = Annotated[sa_asyncio.AsyncSession, Depends(get_session)]
117+
AsyncSession = Annotated[sa_async.AsyncSession, Depends(get_session)]
117118

118119

119120
async def get_scoped_session() -> (
120-
AsyncGenerator[sa_asyncio.async_scoped_session[AsyncSession], None]
121+
AsyncGenerator[sa_async.async_scoped_session[AsyncSession], None]
121122
):
122123
try:
123124
scoped_session = async_scoped_session(
@@ -131,7 +132,7 @@ async def get_scoped_session() -> (
131132

132133

133134
async_scoped_session = Annotated[
134-
sa_asyncio.async_scoped_session[AsyncSession], Depends(get_scoped_session)
135+
sa_async.async_scoped_session[sa_async.AsyncSession], Depends(get_scoped_session)
135136
]
136137

137138

@@ -186,20 +187,21 @@ def _init_engines():
186187

187188

188189
def _init_table():
189-
global _binds, _metadatas
190+
global _binds, _metadatas, _plugins
190191

191192
_binds = {}
193+
_plugins = {}
192194

193-
if len(_engines) == 1: # NOTE: common case: only default engine
194-
_metadatas = {"": Model.metadata}
195-
return
196-
195+
_get_plugin_by_module_name = lru_cache(None)(get_plugin_by_module_name)
197196
for model in Model.__subclasses__():
198197
table: Table | None = getattr(model, "__table__", None)
199198

200199
if table is None or (bind_key := table.info.get("bind_key")) is None:
201200
continue
202201

202+
if plugin := _get_plugin_by_module_name(model.__module__):
203+
_plugins[plugin.name] = plugin
204+
203205
_binds[model] = _engines.get(bind_key, _engines[""])
204206
table.to_metadata(_metadatas.get(bind_key, _metadatas[""]))
205207

nonebot_plugin_orm/__main__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def init(*args, **kwargs) -> None:
8989
@orm.command()
9090
@click.option("-m", "--message", help="描述")
9191
@click.option("--sql", is_flag=True, help="以 SQL 的形式输出修订脚本")
92-
@click.option("--head", default="head", help="基准版本")
92+
@click.option("--head", help="基准版本")
9393
@click.option("--splice", is_flag=True, help="允许非头部修订作为基准版本")
9494
@click.option("--branch-label", help="分支标签")
9595
@click.option(
@@ -128,11 +128,6 @@ def merge(*args, **kwargs) -> Iterable[Script]:
128128
@click.argument("revision", required=False)
129129
@click.option("--sql", is_flag=True, help="以 SQL 的形式输出修订脚本")
130130
@click.option("--tag", help="一个任意的字符串, 可在自定义的 env.py 中使用")
131-
@click.option(
132-
"--fast",
133-
is_flag=True,
134-
help="快速升级到最新版本,不运行修订脚本,直接创建当前的表(只应该在数据库为空、修订较多且只有表结构更改时使用)",
135-
)
136131
@click.pass_obj
137132
def upgrade(*args, **kwargs) -> None:
138133
"""升级到较新版本。"""
@@ -149,6 +144,14 @@ def downgrade(*args, **kwargs) -> None:
149144
return migrate.downgrade(*args, **kwargs)
150145

151146

147+
@orm.command()
148+
@click.argument("revision", required=False)
149+
@click.pass_obj
150+
def sync(*args, **kwargs) -> None:
151+
"""同步数据库模式 (仅用于开发)。"""
152+
return migrate.sync(*args, **kwargs)
153+
154+
152155
@orm.command()
153156
@click.argument("revs", nargs=-1)
154157
@click.pass_obj

0 commit comments

Comments
 (0)