Skip to content

Commit 4c6bc06

Browse files
committed
💥 refactor(alembic)!: nested version locations
1 parent 3b28499 commit 4c6bc06

File tree

3 files changed

+103
-72
lines changed

3 files changed

+103
-72
lines changed

nonebot_plugin_orm/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from alembic import context
44
from sqlalchemy.sql.schema import SchemaItem
55

6-
from nonebot_plugin_orm import migrate
6+
from . import migrate
77

88

99
def no_drop_table(

nonebot_plugin_orm/migrate.py

Lines changed: 94 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from pprint import pformat
99
from argparse import Namespace
10+
from operator import attrgetter
1011
from typing import Any, TextIO, cast
1112
from tempfile import TemporaryDirectory
1213
from configparser import DuplicateSectionError
@@ -16,9 +17,9 @@
1617
import click
1718
import alembic
1819
import sqlalchemy
19-
from nonebot import logger
2020
from alembic.config import Config
2121
from sqlalchemy.util import asbool
22+
from nonebot import logger, get_plugin
2223
from sqlalchemy import MetaData, Connection
2324
from alembic.util.editor import open_in_editor
2425
from alembic.script import Script, ScriptDirectory
@@ -32,7 +33,7 @@
3233
render_python_code,
3334
)
3435

35-
from .utils import is_editable, return_progressbar
36+
from .utils import is_editable, get_parent_plugins, return_progressbar
3637

3738
if sys.version_info >= (3, 11):
3839
from typing import Self
@@ -91,20 +92,28 @@ def __init__(
9192
) -> None:
9293
from . import _engines, _metadatas, plugin_config
9394

94-
if file_ is None and Path("alembic.ini").is_file():
95-
file_ = "alembic.ini"
95+
self._exit_stack = ExitStack()
96+
self._plugin_version_locations = {}
97+
self._temp_dir = Path(self._exit_stack.enter_context(TemporaryDirectory()))
98+
99+
if file_ is None and isinstance(plugin_config.alembic_config, Path):
100+
file_ = plugin_config.alembic_config
96101

97102
if plugin_config.alembic_script_location:
98103
script_location = plugin_config.alembic_script_location
99104
elif (
100-
Path("migrations/env.py").is_file()
101-
and Path("migrations/script.py.mako").is_file()
105+
Path("migrations", "env.py").is_file()
106+
and Path("migrations", "script.py.mako").is_file()
102107
):
103-
script_location = "migrations"
108+
script_location = Path("migrations")
104109
elif len(_engines) == 1:
105-
script_location = str(Path(__file__).parent / "templates" / "generic")
110+
script_location = self._exit_stack.enter_context(
111+
as_file(files(__name__) / "templates" / "generic")
112+
)
106113
else:
107-
script_location = str(Path(__file__).parent / "templates" / "multidb")
114+
script_location = self._exit_stack.enter_context(
115+
as_file(files(__name__) / "templates" / "multidb")
116+
)
108117

109118
super().__init__(
110119
file_,
@@ -130,10 +139,6 @@ def __init__(
130139
},
131140
)
132141

133-
self._exit_stack = ExitStack()
134-
self._plugin_version_locations = {}
135-
self._temp_dir = Path(self._exit_stack.enter_context(TemporaryDirectory()))
136-
137142
self._init_post_write_hooks()
138143
self._init_version_locations()
139144

@@ -173,7 +178,7 @@ def move_script(self, script: Script) -> Path:
173178
except ValueError:
174179
return script_path
175180

176-
plugin_name = (script_path.parent.parts or ("",))[0]
181+
plugin_name = script_path.parent.name
177182
if version_location := self._plugin_version_locations.get(plugin_name):
178183
pass
179184
elif version_location := self._plugin_version_locations.get(""):
@@ -185,12 +190,8 @@ def move_script(self, script: Script) -> Path:
185190
)
186191
return script_path
187192

188-
(version_location / script_path.relative_to(plugin_name).parent).mkdir(
189-
parents=True, exist_ok=True
190-
)
191-
return shutil.move(
192-
script.path, version_location / script_path.relative_to(plugin_name)
193-
)
193+
version_location.mkdir(parents=True, exist_ok=True)
194+
return shutil.move(script.path, version_location)
194195

195196
def _add_post_write_hook(self, name: str, **kwargs: str) -> None:
196197
self.set_section_option(
@@ -238,16 +239,17 @@ def _init_version_locations(self) -> None:
238239

239240
alembic_version_locations = plugin_config.alembic_version_locations
240241
if isinstance(alembic_version_locations, dict):
241-
main_version_location = Path(
242-
alembic_version_locations.get("", "migrations/versions")
243-
)
242+
main_version_location = alembic_version_locations.get("")
244243
else:
245-
main_version_location = Path(
246-
alembic_version_locations or "migrations/versions"
247-
)
248-
self._plugin_version_locations[""] = main_version_location
244+
main_version_location = alembic_version_locations
245+
246+
self._plugin_version_locations[""] = main_version_location or Path(
247+
"migrations", "versions"
248+
)
249249

250-
version_locations = {_data_dir / "migrations": ""}
250+
temp_version_locations: dict[Path, Path] = {
251+
_data_dir / "migrations": self._temp_dir
252+
}
251253

252254
for plugin in _plugins.values():
253255
if plugin.metadata and (
@@ -257,39 +259,53 @@ def _init_version_locations(self) -> None:
257259
else:
258260
version_location = files(plugin.module) / "migrations"
259261

260-
if is_editable(plugin) and isinstance(version_location, Path):
262+
temp_version_location = Path(
263+
*map(attrgetter("name"), reversed(list(get_parent_plugins(plugin)))),
264+
)
265+
266+
if (
267+
not main_version_location
268+
and is_editable(plugin)
269+
and isinstance(version_location, Path)
270+
):
261271
self._plugin_version_locations[plugin.name] = version_location
262272
else:
263273
self._plugin_version_locations[plugin.name] = (
264-
main_version_location / plugin.name
274+
self._plugin_version_locations[""] / temp_version_location
265275
)
266276

267-
version_locations[
277+
temp_version_locations[
268278
self._exit_stack.enter_context(as_file(version_location))
269-
] = plugin.name
279+
] = (self._temp_dir / temp_version_location)
270280

271281
if isinstance(alembic_version_locations, dict):
272-
for name, path in alembic_version_locations.items():
273-
path = self._plugin_version_locations[name] = Path(path)
274-
version_locations[path] = name
282+
for plugin_name, version_location in alembic_version_locations.items():
283+
if not (plugin := get_plugin(plugin_name)):
284+
continue
285+
286+
version_location = Path(version_location)
287+
self._plugin_version_locations[plugin_name] = version_location
288+
temp_version_locations[version_location] = self._temp_dir.joinpath(
289+
*map(
290+
attrgetter("name"),
291+
reversed(list(get_parent_plugins(plugin))),
292+
)
293+
)
275294

276-
version_locations[main_version_location] = ""
295+
temp_version_locations[self._plugin_version_locations[""]] = self._temp_dir
277296

278-
for src, dst in version_locations.items():
297+
for src, dst in temp_version_locations.items():
298+
dst.mkdir(parents=True, exist_ok=True)
279299
with suppress(FileNotFoundError, shutil.Error):
280-
shutil.copytree(src, self._temp_dir / dst, dirs_exist_ok=True)
300+
shutil.copytree(src, dst, dirs_exist_ok=True)
281301

282302
pathsep = _SPLIT_ON_PATH[self.get_main_option("version_path_separator")]
283303
self.set_main_option(
284304
"version_locations",
285305
pathsep.join(
286-
map(
287-
str,
288-
(
289-
self._temp_dir,
290-
*filter(methodcaller("is_dir"), self._temp_dir.iterdir()),
291-
),
292-
)
306+
str(path)
307+
for path in self._temp_dir.glob("**")
308+
if path.name != "__pycache__"
293309
),
294310
)
295311

@@ -301,11 +317,10 @@ def ignore(path: str, names: list[str]) -> set[str]:
301317
path_ = Path(path)
302318

303319
return set(
304-
filter(
305-
lambda name: Path(name).suffix in {".py", ".pyc", ".pyo"}
306-
and path_ / name not in run_script_path,
307-
names,
308-
)
320+
name
321+
for name in names
322+
if Path(name).suffix in {".py", ".pyc", ".pyo"}
323+
and path_ / name not in run_script_path
309324
)
310325

311326
run_script_path = set(
@@ -377,7 +392,7 @@ def revision(
377392
head: str | None = None,
378393
splice: bool = False,
379394
branch_label: str | None = None,
380-
version_path: Path | None = None,
395+
version_path: str | Path | None = None,
381396
rev_id: str | None = None,
382397
depends_on: str | None = None,
383398
process_revision_directives: ProcessRevisionDirectiveFn | None = None,
@@ -401,17 +416,30 @@ def revision(
401416
if head is None:
402417
head = "base" if branch_label else "head"
403418

419+
if not version_path and branch_label and (plugin := _plugins.get(branch_label)):
420+
version_path = str(
421+
config._temp_dir.joinpath(
422+
*map(
423+
attrgetter("name"),
424+
reversed(list(get_parent_plugins(plugin))),
425+
)
426+
)
427+
)
428+
404429
if version_path:
405-
version_locations = config.get_main_option("version_locations")
430+
version_path = Path(version_path).resolve()
431+
version_locations = config.get_main_option("version_locations", "")
406432
pathsep = _SPLIT_ON_PATH[config.get_main_option("version_path_separator")]
407-
config.set_main_option(
408-
"version_locations", f"{version_locations}{pathsep}{version_path}"
409-
)
410-
logger.warning(
411-
f'临时将目录 "{version_path}" 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
412-
)
413-
elif branch_label in _plugins:
414-
version_path = config._temp_dir / branch_label
433+
434+
if version_path in (
435+
Path(path).resolve() for path in version_locations.split(pathsep)
436+
):
437+
config.set_main_option(
438+
"version_locations", f"{version_locations}{pathsep}{version_path}"
439+
)
440+
logger.warning(
441+
f'临时将目录 "{version_path}" 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
442+
)
415443

416444
script = ScriptDirectory.from_config(config)
417445

@@ -740,7 +768,7 @@ def show(config: AlembicConfig, revs: str | Sequence[str] = "current") -> None:
740768
):
741769
script.run_env()
742770

743-
for sc in cast(Tuple[Script], script.get_revisions(revs)):
771+
for sc in cast("tuple[Script]", script.get_revisions(revs)):
744772
config.print_stdout(sc.log_entry)
745773

746774

@@ -828,7 +856,7 @@ def heads(
828856
else:
829857
heads = script.get_revisions(script.get_heads())
830858

831-
for rev in cast(Tuple[Script], heads):
859+
for rev in cast("tuple[Script]", heads):
832860
config.print_stdout(
833861
rev.cmd_format(verbose, include_branches=True, tree_indicators=False)
834862
)
@@ -880,7 +908,7 @@ def display_version(
880908
"Current revision(s) for %s:",
881909
cast(Connection, context.connection).engine.url.render_as_string(),
882910
)
883-
for sc in cast(Set[Script], script.get_all_current(rev)):
911+
for sc in cast("set[Script]", script.get_all_current(rev)):
884912
config.print_stdout(sc.cmd_format(verbose))
885913

886914
return ()
@@ -960,7 +988,7 @@ def edit_current(rev, _) -> Iterable[StampStep | RevisionStep]:
960988
if not rev:
961989
raise click.UsageError("当前没有迁移")
962990

963-
for sc in cast(Tuple[Script], script.get_revisions(rev)):
991+
for sc in cast("tuple[Script]", script.get_revisions(rev)):
964992
script_path = config.move_script(sc)
965993
open_in_editor(str(script_path))
966994

@@ -969,12 +997,12 @@ def edit_current(rev, _) -> Iterable[StampStep | RevisionStep]:
969997
with EnvironmentContext(config, script, fn=edit_current):
970998
script.run_env()
971999
else:
972-
revs = cast(Tuple[Script, ...], script.get_revisions(rev))
1000+
revs = cast("tuple[Script, ...]", script.get_revisions(rev))
9731001

9741002
if not revs:
9751003
raise click.BadParameter(f'没有 "{rev}" 指示的迁移脚本')
9761004

977-
for sc in cast(Tuple[Script], revs):
1005+
for sc in cast("tuple[Script]", revs):
9781006
script_path = config.move_script(sc)
9791007
open_in_editor(str(script_path))
9801008

nonebot_plugin_orm/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from typing import Any, TypeVar
1212
from typing_extensions import Annotated
1313
from dataclasses import field, dataclass
14-
from collections.abc import Callable, Iterable
1514
from inspect import Parameter, Signature, isclass
15+
from collections.abc import Callable, Iterable, Generator
1616
from importlib.metadata import Distribution, PackageNotFoundError, distribution
1717

1818
import click
@@ -218,16 +218,19 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Iterable[_T]:
218218
return wrapper
219219

220220

221+
def get_parent_plugins(plugin: Plugin | None) -> Generator[Plugin, Any, None]:
222+
while plugin:
223+
yield plugin
224+
plugin = plugin.parent_plugin
225+
226+
221227
pkgs = packages_distributions()
222228

223229

224230
def is_editable(plugin: Plugin) -> bool:
225-
"""Check if the distribution is installed in editable mode"""
226-
while plugin.parent_plugin:
227-
plugin = plugin.parent_plugin
231+
*_, plugin = get_parent_plugins(plugin)
228232

229233
path = files(plugin.module)
230-
231234
if not isinstance(path, Path) or "site-packages" in path.parts:
232235
return False
233236

0 commit comments

Comments
 (0)