Skip to content

Commit 010ca01

Browse files
committed
🐛 fix(alembic): patch session during migration
1 parent 317e401 commit 010ca01

File tree

4 files changed

+99
-84
lines changed

4 files changed

+99
-84
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import logging
44
from argparse import Namespace
5-
from functools import wraps, lru_cache
5+
from functools import cache, wraps
6+
from collections.abc import Generator
7+
from contextlib import contextmanager
68
from typing_extensions import Any, Annotated
79

810
import click
911
from nonebot.rule import Rule
12+
from alembic.op import get_bind
1013
import sqlalchemy.ext.asyncio as sa_async
1114
from nonebot.permission import Permission
1215
from sqlalchemy.util import greenlet_spawn
@@ -16,8 +19,8 @@
1619
from nonebot.plugin import Plugin, PluginMetadata
1720
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
1821
from nonebot.matcher import Matcher, current_event, current_matcher
19-
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
2022
from nonebot import logger, require, get_driver, get_plugin_by_module_name
23+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncConnection, create_async_engine
2124

2225
from . import migrate
2326
from .param import ORMParam
@@ -101,22 +104,6 @@ async def init_orm() -> None:
101104
await greenlet_spawn(migrate.sync, alembic_config)
102105

103106

104-
def _init_orm():
105-
global _session_factory, _scoped_sessions
106-
107-
_init_engines()
108-
_init_table()
109-
_session_factory = sa_async.async_sessionmaker(
110-
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
111-
)
112-
_scoped_sessions = sa_async.async_scoped_session(
113-
_session_factory,
114-
lambda: (id(current_event.get(None)), current_matcher.get(None)),
115-
)
116-
117-
run_postprocessor(_scoped_sessions.remove)
118-
119-
120107
def get_session(**local_kw: Any) -> sa_async.AsyncSession:
121108
try:
122109
return _session_factory(**local_kw)
@@ -149,6 +136,26 @@ def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]
149136
]
150137

151138

139+
@contextmanager
140+
def _patch_migrate_session() -> Generator[None, Any, None]:
141+
global _session_factory, _scoped_sessions
142+
143+
session_factory, scoped_sessions = _session_factory, _scoped_sessions
144+
145+
_session_factory = sa_async.async_sessionmaker(
146+
AsyncConnection._retrieve_proxy_for_target(get_bind()),
147+
**plugin_config.sqlalchemy_session_options,
148+
)
149+
_scoped_sessions = sa_async.async_scoped_session(
150+
_session_factory,
151+
lambda: (id(current_event.get(None)), current_matcher.get(None)),
152+
)
153+
154+
yield
155+
156+
_session_factory, _scoped_sessions = session_factory, scoped_sessions
157+
158+
152159
def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine:
153160
if isinstance(engine, AsyncEngine):
154161
return engine
@@ -200,7 +207,7 @@ def _init_table():
200207
_binds = {}
201208
_plugins = {}
202209

203-
_get_plugin_by_module_name = lru_cache(None)(get_plugin_by_module_name)
210+
_get_plugin_by_module_name = cache(get_plugin_by_module_name)
204211
for model in set(get_subclasses(Model)):
205212
table: Table | None = getattr(model, "__table__", None)
206213

@@ -214,6 +221,22 @@ def _init_table():
214221
table.to_metadata(_metadatas.get(bind_key, _metadatas[""]))
215222

216223

224+
def _init_orm():
225+
global _session_factory, _scoped_sessions
226+
227+
_init_engines()
228+
_init_table()
229+
_session_factory = sa_async.async_sessionmaker(
230+
_engines[""], binds=_binds, **plugin_config.sqlalchemy_session_options
231+
)
232+
_scoped_sessions = sa_async.async_scoped_session(
233+
_session_factory,
234+
lambda: (id(current_event.get(None)), current_matcher.get(None)),
235+
)
236+
237+
run_postprocessor(_scoped_sessions.remove)
238+
239+
217240
def _init_logger():
218241
handler = LoguruHandler()
219242
logging.getLogger("alembic").addHandler(handler)

nonebot_plugin_orm/migrate.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from alembic.config import Config
2222
from sqlalchemy.util import asbool
2323
from nonebot import logger, get_plugin
24+
from nonebot.matcher import current_matcher
2425
from sqlalchemy import MetaData, Connection
2526
from alembic.util.editor import open_in_editor
2627
from alembic.script import Script, ScriptDirectory
2728
from alembic.util.langhelpers import rev_id as _rev_id
2829
from alembic.operations.ops import UpgradeOps, DowngradeOps
30+
from sqlalchemy.ext.asyncio import AsyncConnection, async_sessionmaker
2931
from alembic.migration import StampStep, RevisionStep, MigrationContext
3032
from alembic.runtime.environment import EnvironmentContext, ProcessRevisionDirectiveFn
3133
from alembic.autogenerate.api import (
@@ -121,21 +123,17 @@ def __init__(
121123
stdout,
122124
cmd_opts,
123125
{
124-
**{
125-
"script_location": script_location,
126-
"prepend_sys_path": ".",
127-
"revision_environment": "true",
128-
"version_path_separator": "os",
129-
},
130-
**config_args,
131-
},
126+
"script_location": script_location,
127+
"prepend_sys_path": ".",
128+
"revision_environment": "true",
129+
"version_path_separator": "os",
130+
}
131+
| dict(config_args),
132132
{
133-
**{
134-
"engines": _engines,
135-
"metadatas": _metadatas,
136-
},
137-
**attributes,
138-
},
133+
"engines": _engines,
134+
"metadatas": _metadatas,
135+
}
136+
| attributes,
139137
)
140138

141139
self._init_post_write_hooks()
@@ -646,7 +644,11 @@ def upgrade(
646644

647645
@return_progressbar
648646
def upgrade(rev, _) -> Iterable[StampStep | RevisionStep]:
649-
yield from script._upgrade_revs(revision, rev)
647+
from . import _patch_migrate_session
648+
649+
with _patch_migrate_session():
650+
yield from script._upgrade_revs(revision, rev)
651+
650652
_move_run_scripts(config, script, revision)
651653

652654
with EnvironmentContext(
@@ -691,7 +693,11 @@ def downgrade(
691693

692694
@return_progressbar
693695
def downgrade(rev, _) -> Iterable[StampStep | RevisionStep]:
694-
yield from script._downgrade_revs(revision, rev)
696+
from . import _patch_migrate_session
697+
698+
with _patch_migrate_session():
699+
yield from script._downgrade_revs(revision, rev)
700+
695701
_move_run_scripts(config, script, revision)
696702

697703
with EnvironmentContext(

nonebot_plugin_orm/templates/generic/env.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from typing import cast
4+
from typing import Any, cast
55

66
from alembic import context
77
from sqlalchemy import Connection
@@ -36,34 +36,26 @@ def run_migrations_offline() -> None:
3636
在这里调用 context.execute() 会将给定的字符串写入到脚本输出.
3737
"""
3838

39-
context.configure(
40-
**{
41-
**dict(
42-
url=engine.url,
43-
dialect_opts={"paramstyle": "named"},
44-
target_metadata=target_metadata,
45-
literal_binds=True,
46-
),
47-
**plugin_config.alembic_context,
48-
}
49-
)
39+
opts: dict[str, Any] = {
40+
"url": engine.url,
41+
"dialect_opts": {"paramstyle": "named"},
42+
"target_metadata": target_metadata,
43+
"literal_binds": True,
44+
} | plugin_config.alembic_context
45+
context.configure(**opts)
5046

5147
with context.begin_transaction():
5248
context.run_migrations()
5349

5450

5551
def do_run_migrations(connection: Connection) -> None:
56-
context.configure(
57-
**{
58-
**dict(
59-
connection=connection,
60-
render_as_batch=True,
61-
target_metadata=target_metadata,
62-
include_object=no_drop_table,
63-
),
64-
**plugin_config.alembic_context,
65-
}
66-
)
52+
opts: dict[str, Any] = {
53+
"connection": connection,
54+
"render_as_batch": True,
55+
"target_metadata": target_metadata,
56+
"include_object": no_drop_table,
57+
} | plugin_config.alembic_context
58+
context.configure(**opts)
6759

6860
with context.begin_transaction():
6961
context.run_migrations()

nonebot_plugin_orm/templates/multidb/env.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from typing import cast
4+
from typing import Any, cast
55

66
from alembic import context
77
from sqlalchemy.util import await_only
@@ -51,37 +51,31 @@ def run_migrations_offline() -> None:
5151
config.print_stdout(f"迁移数据库 {name or '<default>'} 中 ...")
5252
file_ = f"{name}.sql"
5353
with open(file_, "w") as buffer:
54-
context.configure(
55-
**{
56-
**dict(
57-
url=engine.url,
58-
dialect_opts={"paramstyle": "named"},
59-
output_buffer=buffer,
60-
target_metadata=target_metadatas["name"],
61-
literal_binds=True,
62-
),
63-
**plugin_config.alembic_context,
64-
}
65-
)
54+
opts: dict[str, Any] = {
55+
"url": engine.url,
56+
"dialect_opts": {"paramstyle": "named"},
57+
"output_buffer": buffer,
58+
"target_metadata": target_metadatas[name],
59+
"literal_binds": True,
60+
} | plugin_config.alembic_context
61+
context.configure(**opts)
62+
6663
with context.begin_transaction():
6764
context.run_migrations(name=name)
6865
config.print_stdout(f"将输出写入到 {file_}")
6966

7067

7168
def do_run_migrations(conn: Connection, name: str, metadata: MetaData) -> None:
72-
context.configure(
73-
**{
74-
**dict(
75-
connection=conn,
76-
render_as_batch=True,
77-
target_metadata=metadata,
78-
include_object=no_drop_table,
79-
upgrade_token=f"{name}_upgrades",
80-
downgrade_token=f"{name}_downgrades",
81-
),
82-
**plugin_config.alembic_context,
83-
}
84-
)
69+
opts: dict[str, Any] = {
70+
"connection": conn,
71+
"render_as_batch": True,
72+
"target_metadata": metadata,
73+
"include_object": no_drop_table,
74+
"upgrade_token": f"{name}_upgrades",
75+
"downgrade_token": f"{name}_downgrades",
76+
} | plugin_config.alembic_context
77+
context.configure(**opts)
78+
8579
context.run_migrations(name=name)
8680

8781

0 commit comments

Comments
 (0)