Skip to content

Commit c1f08e1

Browse files
committed
💥 refactor(alembic)!: AlembicConfig initialization and version locations handling
1 parent 10c8df6 commit c1f08e1

File tree

5 files changed

+186
-117
lines changed

5 files changed

+186
-117
lines changed

nonebot_plugin_orm/__main__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Iterable
55
from argparse import Namespace
6+
from warnings import catch_warnings, filterwarnings
67

78
import click
89
from alembic.script import Script
@@ -38,15 +39,19 @@ def orm(
3839
ctx: click.Context, config: Path, name: str, x: tuple[str, ...], quite: bool
3940
) -> None:
4041
ctx.show_default = True
42+
use_tempdir = ctx.invoked_subcommand in ("revision", "merge", "edit")
4143

4244
if isinstance(plugin_config.alembic_config, AlembicConfig):
4345
ctx.obj = plugin_config.alembic_config
4446
else:
4547
ctx.obj = AlembicConfig(
46-
config, name, cmd_opts=Namespace(config=config, name=name, x=x, quite=quite)
48+
config, name, cmd_opts=Namespace(**ctx.params), use_tempdir=use_tempdir
4749
)
4850

49-
ctx.with_resource(ctx.obj)
51+
ctx.call_on_close(ctx.obj.close)
52+
if use_tempdir:
53+
ctx.with_resource(catch_warnings())
54+
filterwarnings("ignore", r"Revision \w* is present more than once", UserWarning)
5055

5156

5257
@orm.result_callback()

nonebot_plugin_orm/migrate.py

Lines changed: 118 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
import sys
55
import shutil
66
from pathlib import Path
7+
from itertools import chain
78
from argparse import Namespace
89
from operator import methodcaller
9-
from typing import Any, Tuple, TextIO, cast
10+
from tempfile import TemporaryDirectory
1011
from configparser import DuplicateSectionError
11-
from tempfile import TemporaryDirectory, tempdir
12-
from contextlib import AsyncExitStack, suppress, contextmanager
12+
from typing_extensions import ParamSpec, Concatenate
13+
from typing import Any, Tuple, TextIO, TypeVar, Callable, cast
1314
from collections.abc import Mapping, Iterable, Sequence, Generator
15+
from contextlib import ExitStack, AsyncExitStack, suppress, contextmanager
1416

1517
import click
1618
from alembic.config import Config
17-
from nonebot.plugin import Plugin
1819
from alembic.operations.ops import UpgradeOps
1920
from alembic.util.editor import open_in_editor
2021
from alembic.script import Script, ScriptDirectory
@@ -29,8 +30,10 @@
2930

3031
if sys.version_info >= (3, 12):
3132
from typing import Self
33+
from importlib.resources import files, as_file
3234
else:
3335
from typing_extensions import Self
36+
from importlib_resources import files, as_file
3437

3538

3639
__all__ = (
@@ -52,7 +55,8 @@
5255
"ensure_version",
5356
)
5457

55-
58+
_T = TypeVar("_T")
59+
_P = ParamSpec("_P")
5660
_SPLIT_ON_PATH = {
5761
None: " ",
5862
"space": " ",
@@ -63,8 +67,9 @@
6367

6468

6569
class AlembicConfig(Config):
66-
_tempdir: TemporaryDirectory
70+
_exit_stack: ExitStack
6771
_plugin_version_locations: dict[str, Path]
72+
_temp_dir: TemporaryDirectory | None = None
6873

6974
def __init__(
7075
self,
@@ -75,6 +80,7 @@ def __init__(
7580
cmd_opts: Namespace | None = None,
7681
config_args: Mapping[str, Any] = {},
7782
attributes: dict = {},
83+
use_tempdir: bool = True,
7884
) -> None:
7985
from . import _engines, _metadatas, plugin_config
8086

@@ -113,64 +119,23 @@ def __init__(
113119
},
114120
)
115121

116-
self._init_post_write_hooks()
117-
118-
def __enter__(self) -> Self:
119-
from . import plugin_config
120-
121-
self._tempdir = TemporaryDirectory()
122+
self._exit_stack = ExitStack()
122123
self._plugin_version_locations = {}
124+
if use_tempdir:
125+
self._temp_dir = TemporaryDirectory()
126+
self._exit_stack.enter_context(self._temp_dir)
123127

124-
if self.get_main_option("version_locations"):
125-
# NOTE: skip if explicitly set
126-
return self
127-
128-
if isinstance(plugin_config.alembic_version_locations, dict):
129-
if version_locations := plugin_config.alembic_version_locations.get(""):
130-
self._plugin_version_locations[""] = Path(version_locations)
131-
else:
132-
self._plugin_version_locations[""] = Path(
133-
plugin_config.alembic_version_locations or "migrations/versions"
134-
)
135-
136-
tempdir = Path(self._tempdir.name)
137-
138-
for plugin in get_loaded_plugins():
139-
if not plugin.metadata or not (
140-
version_module := plugin.metadata.extra.get("orm_version_location")
141-
):
142-
continue
143-
144-
version_location = version_module.__path__[0]
145-
if is_editable(plugin):
146-
self._plugin_version_locations[plugin.name] = Path(version_location)
147-
elif version_locations := self._plugin_version_locations.get(""):
148-
self._plugin_version_locations[plugin.name] = (
149-
version_locations / plugin.name
150-
)
151-
shutil.copytree(version_location, tempdir / plugin.name)
152-
153-
if isinstance(plugin_config.alembic_version_locations, dict):
154-
for name, path in plugin_config.alembic_version_locations.items():
155-
with suppress(FileNotFoundError):
156-
shutil.copytree(path, tempdir / name, dirs_exist_ok=True)
157-
self._plugin_version_locations[name] = Path(path)
158-
else:
159-
with suppress(FileNotFoundError):
160-
shutil.copytree(
161-
self._plugin_version_locations[""], tempdir, dirs_exist_ok=True
162-
)
163-
164-
pathsep = _SPLIT_ON_PATH[self.get_main_option("version_path_separator")]
165-
version_location = pathsep.join(
166-
map(str, (tempdir, *filter(methodcaller("is_dir"), tempdir.iterdir())))
167-
)
168-
self.set_main_option("version_locations", version_location)
128+
self._init_post_write_hooks()
129+
self._init_version_locations()
169130

131+
def __enter__(self: Self) -> Self:
170132
return self
171133

172134
def __exit__(self, *_) -> None:
173-
self._tempdir.cleanup()
135+
self.close()
136+
137+
def close(self) -> None:
138+
self._exit_stack.close()
174139

175140
def get_template_directory(self) -> str:
176141
return str(Path(__file__).parent / "templates")
@@ -194,8 +159,11 @@ def status(self, status_msg: str) -> Generator[None, Any, None]:
194159
def move_script(self, script: Script) -> Path:
195160
script_path = Path(script.path)
196161

162+
if not self._temp_dir:
163+
return script_path
164+
197165
try:
198-
script_path = script_path.relative_to(self._tempdir.name)
166+
script_path = script_path.relative_to(self._temp_dir.name)
199167
except ValueError:
200168
return script_path
201169

@@ -256,6 +224,89 @@ def _init_post_write_hooks(self) -> None:
256224
options="REVISION_SCRIPT_FILENAME",
257225
)
258226

227+
def _init_version_locations(self) -> None:
228+
from . import plugin_config
229+
230+
alembic_version_locations = plugin_config.alembic_version_locations
231+
232+
if self.get_main_option("version_locations"):
233+
# NOTE: skip if explicitly set
234+
return
235+
236+
if isinstance(alembic_version_locations, dict):
237+
if _main_version_location := alembic_version_locations.get(""):
238+
main_version_location = self._plugin_version_locations[""] = Path(
239+
_main_version_location
240+
)
241+
else:
242+
main_version_location = None
243+
else:
244+
main_version_location = self._plugin_version_locations[""] = Path(
245+
alembic_version_locations or "migrations/versions"
246+
)
247+
248+
temp_dir = Path(self._temp_dir.name) if self._temp_dir else None
249+
version_locations = {}
250+
251+
for plugin in get_loaded_plugins():
252+
if not plugin.metadata or not (
253+
version_module := plugin.metadata.extra.get("orm_version_location")
254+
):
255+
continue
256+
257+
version_location = files(version_module)
258+
259+
if is_editable(plugin) and isinstance(version_location, Path):
260+
self._plugin_version_locations[plugin.name] = version_location
261+
elif main_version_location:
262+
self._plugin_version_locations[plugin.name] = (
263+
main_version_location / plugin.name
264+
)
265+
266+
version_location = self._exit_stack.enter_context(as_file(version_location))
267+
version_locations[version_location] = plugin.name
268+
269+
if isinstance(alembic_version_locations, dict):
270+
for name, path in alembic_version_locations.items():
271+
path = self._plugin_version_locations[name] = Path(path)
272+
version_locations[path] = name
273+
elif main_version_location:
274+
version_locations[main_version_location] = ""
275+
276+
if temp_dir:
277+
for src, dst in version_locations.items():
278+
with suppress(FileNotFoundError):
279+
shutil.copytree(src, temp_dir / dst, dirs_exist_ok=True)
280+
281+
version_locations = (
282+
temp_dir,
283+
*filter(methodcaller("is_dir"), temp_dir.iterdir()),
284+
)
285+
else:
286+
version_locations = reversed(version_locations)
287+
288+
if main_version_location:
289+
version_locations = chain(
290+
filter(methodcaller("is_dir"), main_version_location.iterdir()),
291+
version_locations,
292+
)
293+
294+
pathsep = _SPLIT_ON_PATH[self.get_main_option("version_path_separator")]
295+
self.set_main_option(
296+
"version_locations", pathsep.join(map(str, version_locations))
297+
)
298+
299+
300+
def use_tempdir(
301+
func: Callable[Concatenate[AlembicConfig, _P], _T]
302+
) -> Callable[Concatenate[AlembicConfig, _P], _T]:
303+
def wrapper(config: AlembicConfig, *args: _P.args, **kwargs: _P.kwargs) -> _T:
304+
if config._temp_dir:
305+
return func(config, *args, **kwargs)
306+
raise RuntimeError("AlembicConfig 未启用临时目录")
307+
308+
return wrapper
309+
259310

260311
def list_templates(config: AlembicConfig) -> None:
261312
"""列出所有可用的模板。
@@ -310,6 +361,7 @@ def init(
310361
)
311362

312363

364+
@use_tempdir
313365
def revision(
314366
config: AlembicConfig,
315367
message: str | None = None,
@@ -351,7 +403,9 @@ def revision(
351403
and plugin.metadata
352404
and plugin.metadata.extra.get("orm_version_location")
353405
):
354-
version_path = Path(config._tempdir.name) / branch_label
406+
version_path = (
407+
Path(cast(TemporaryDirectory, config._temp_dir).name) / branch_label
408+
)
355409

356410
script_directory = ScriptDirectory.from_config(config)
357411

@@ -455,6 +509,7 @@ def retrieve_migrations(rev, context):
455509
config.print_stdout("没有检测到新的升级操作")
456510

457511

512+
@use_tempdir
458513
def merge(
459514
config: AlembicConfig,
460515
revisions: tuple[str, ...],
@@ -530,12 +585,12 @@ def upgrade(
530585
def upgrade(rev, _):
531586
nonlocal fast
532587

533-
if fast and revision in ("head", "heads") and not script.get_all_current(rev):
588+
if fast and revision in {"head", "heads"} and not script.get_all_current(rev):
534589
await_fallback(_upgrade_fast(config))
535590
return script._stamp_revs(revision, rev)
536591
else:
537592
fast = False
538-
return script._upgrade_revs(revision, rev)
593+
return script._upgrade_revs(revision, rev) # type: ignore
539594

540595
with EnvironmentContext(
541596
config,
@@ -817,6 +872,7 @@ def do_stamp(rev, _):
817872
script.run_env()
818873

819874

875+
@use_tempdir
820876
def edit(config: AlembicConfig, rev: str = "current") -> None:
821877
"""使用 `$EDITOR` 编辑修订文件。
822878
@@ -826,9 +882,6 @@ def edit(config: AlembicConfig, rev: str = "current") -> None:
826882
"""
827883

828884
script = ScriptDirectory.from_config(config)
829-
temp_version_locations = click.get_current_context().meta.get(
830-
f"{__name__}.temp_version_locations"
831-
)
832885

833886
if rev == "current":
834887

nonebot_plugin_orm/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import logging
66
from io import StringIO
7+
from pathlib import Path
78
from typing import TypeVar
89
from contextlib import suppress
910
from functools import wraps, lru_cache
@@ -15,12 +16,16 @@
1516
from nonebot.params import Depends
1617
from nonebot import logger, get_driver
1718

19+
if sys.version_info >= (3, 9):
20+
from importlib.resources import files
21+
else:
22+
from importlib_resources import files
23+
1824
if sys.version_info >= (3, 10):
1925
from typing import ParamSpec
2026
from importlib.metadata import packages_distributions
2127
else:
2228
from typing_extensions import ParamSpec
23-
2429
from importlib_metadata import packages_distributions
2530

2631

@@ -96,7 +101,8 @@ def is_editable(plugin: Plugin) -> bool:
96101
)
97102

98103
if not dist:
99-
return "site-packages" not in plugin.module.__path__[0]
104+
path = files(plugin.module)
105+
return isinstance(path, Path) and "site-packages" not in path.parts
100106

101107
# https://github.com/pdm-project/pdm/blob/fee1e6bffd7de30315e2134e19f9a6f58e15867c/src/pdm/utils.py#L361-L374
102108
if getattr(dist, "link_file", None) is not None:

0 commit comments

Comments
 (0)