44import sys
55import shutil
66from pathlib import Path
7+ from itertools import chain
78from argparse import Namespace
89from operator import methodcaller
9- from typing import Any , Tuple , TextIO , cast
10+ from tempfile import TemporaryDirectory
1011from configparser import DuplicateSectionError
11- from tempfile import TemporaryDirectory , tempdir
12- from contextlib import AsyncExitStack , suppress , contextmanager
12+ from typing_extensions import ParamSpec , Concatenate
13+ from typing import Any , Tuple , TextIO , TypeVar , Callable , cast
1314from collections .abc import Mapping , Iterable , Sequence , Generator
15+ from contextlib import ExitStack , AsyncExitStack , suppress , contextmanager
1416
1517import click
1618from alembic .config import Config
17- from nonebot .plugin import Plugin
1819from alembic .operations .ops import UpgradeOps
1920from alembic .util .editor import open_in_editor
2021from alembic .script import Script , ScriptDirectory
2930
3031if sys .version_info >= (3 , 12 ):
3132 from typing import Self
33+ from importlib .resources import files , as_file
3234else :
3335 from typing_extensions import Self
36+ from importlib_resources import files , as_file
3437
3538
3639__all__ = (
5255 "ensure_version" ,
5356)
5457
55-
58+ _T = TypeVar ("_T" )
59+ _P = ParamSpec ("_P" )
5660_SPLIT_ON_PATH = {
5761 None : " " ,
5862 "space" : " " ,
6367
6468
6569class AlembicConfig (Config ):
66- _tempdir : TemporaryDirectory
70+ _exit_stack : ExitStack
6771 _plugin_version_locations : dict [str , Path ]
72+ _temp_dir : TemporaryDirectory | None = None
6873
6974 def __init__ (
7075 self ,
@@ -75,6 +80,7 @@ def __init__(
7580 cmd_opts : Namespace | None = None ,
7681 config_args : Mapping [str , Any ] = {},
7782 attributes : dict = {},
83+ use_tempdir : bool = True ,
7884 ) -> None :
7985 from . import _engines , _metadatas , plugin_config
8086
@@ -113,64 +119,23 @@ def __init__(
113119 },
114120 )
115121
116- self ._init_post_write_hooks ()
117-
118- def __enter__ (self ) -> Self :
119- from . import plugin_config
120-
121- self ._tempdir = TemporaryDirectory ()
122+ self ._exit_stack = ExitStack ()
122123 self ._plugin_version_locations = {}
124+ if use_tempdir :
125+ self ._temp_dir = TemporaryDirectory ()
126+ self ._exit_stack .enter_context (self ._temp_dir )
123127
124- if self .get_main_option ("version_locations" ):
125- # NOTE: skip if explicitly set
126- return self
127-
128- if isinstance (plugin_config .alembic_version_locations , dict ):
129- if version_locations := plugin_config .alembic_version_locations .get ("" ):
130- self ._plugin_version_locations ["" ] = Path (version_locations )
131- else :
132- self ._plugin_version_locations ["" ] = Path (
133- plugin_config .alembic_version_locations or "migrations/versions"
134- )
135-
136- tempdir = Path (self ._tempdir .name )
137-
138- for plugin in get_loaded_plugins ():
139- if not plugin .metadata or not (
140- version_module := plugin .metadata .extra .get ("orm_version_location" )
141- ):
142- continue
143-
144- version_location = version_module .__path__ [0 ]
145- if is_editable (plugin ):
146- self ._plugin_version_locations [plugin .name ] = Path (version_location )
147- elif version_locations := self ._plugin_version_locations .get ("" ):
148- self ._plugin_version_locations [plugin .name ] = (
149- version_locations / plugin .name
150- )
151- shutil .copytree (version_location , tempdir / plugin .name )
152-
153- if isinstance (plugin_config .alembic_version_locations , dict ):
154- for name , path in plugin_config .alembic_version_locations .items ():
155- with suppress (FileNotFoundError ):
156- shutil .copytree (path , tempdir / name , dirs_exist_ok = True )
157- self ._plugin_version_locations [name ] = Path (path )
158- else :
159- with suppress (FileNotFoundError ):
160- shutil .copytree (
161- self ._plugin_version_locations ["" ], tempdir , dirs_exist_ok = True
162- )
163-
164- pathsep = _SPLIT_ON_PATH [self .get_main_option ("version_path_separator" )]
165- version_location = pathsep .join (
166- map (str , (tempdir , * filter (methodcaller ("is_dir" ), tempdir .iterdir ())))
167- )
168- self .set_main_option ("version_locations" , version_location )
128+ self ._init_post_write_hooks ()
129+ self ._init_version_locations ()
169130
131+ def __enter__ (self : Self ) -> Self :
170132 return self
171133
172134 def __exit__ (self , * _ ) -> None :
173- self ._tempdir .cleanup ()
135+ self .close ()
136+
137+ def close (self ) -> None :
138+ self ._exit_stack .close ()
174139
175140 def get_template_directory (self ) -> str :
176141 return str (Path (__file__ ).parent / "templates" )
@@ -194,8 +159,11 @@ def status(self, status_msg: str) -> Generator[None, Any, None]:
194159 def move_script (self , script : Script ) -> Path :
195160 script_path = Path (script .path )
196161
162+ if not self ._temp_dir :
163+ return script_path
164+
197165 try :
198- script_path = script_path .relative_to (self ._tempdir .name )
166+ script_path = script_path .relative_to (self ._temp_dir .name )
199167 except ValueError :
200168 return script_path
201169
@@ -256,6 +224,89 @@ def _init_post_write_hooks(self) -> None:
256224 options = "REVISION_SCRIPT_FILENAME" ,
257225 )
258226
227+ def _init_version_locations (self ) -> None :
228+ from . import plugin_config
229+
230+ alembic_version_locations = plugin_config .alembic_version_locations
231+
232+ if self .get_main_option ("version_locations" ):
233+ # NOTE: skip if explicitly set
234+ return
235+
236+ if isinstance (alembic_version_locations , dict ):
237+ if _main_version_location := alembic_version_locations .get ("" ):
238+ main_version_location = self ._plugin_version_locations ["" ] = Path (
239+ _main_version_location
240+ )
241+ else :
242+ main_version_location = None
243+ else :
244+ main_version_location = self ._plugin_version_locations ["" ] = Path (
245+ alembic_version_locations or "migrations/versions"
246+ )
247+
248+ temp_dir = Path (self ._temp_dir .name ) if self ._temp_dir else None
249+ version_locations = {}
250+
251+ for plugin in get_loaded_plugins ():
252+ if not plugin .metadata or not (
253+ version_module := plugin .metadata .extra .get ("orm_version_location" )
254+ ):
255+ continue
256+
257+ version_location = files (version_module )
258+
259+ if is_editable (plugin ) and isinstance (version_location , Path ):
260+ self ._plugin_version_locations [plugin .name ] = version_location
261+ elif main_version_location :
262+ self ._plugin_version_locations [plugin .name ] = (
263+ main_version_location / plugin .name
264+ )
265+
266+ version_location = self ._exit_stack .enter_context (as_file (version_location ))
267+ version_locations [version_location ] = plugin .name
268+
269+ if isinstance (alembic_version_locations , dict ):
270+ for name , path in alembic_version_locations .items ():
271+ path = self ._plugin_version_locations [name ] = Path (path )
272+ version_locations [path ] = name
273+ elif main_version_location :
274+ version_locations [main_version_location ] = ""
275+
276+ if temp_dir :
277+ for src , dst in version_locations .items ():
278+ with suppress (FileNotFoundError ):
279+ shutil .copytree (src , temp_dir / dst , dirs_exist_ok = True )
280+
281+ version_locations = (
282+ temp_dir ,
283+ * filter (methodcaller ("is_dir" ), temp_dir .iterdir ()),
284+ )
285+ else :
286+ version_locations = reversed (version_locations )
287+
288+ if main_version_location :
289+ version_locations = chain (
290+ filter (methodcaller ("is_dir" ), main_version_location .iterdir ()),
291+ version_locations ,
292+ )
293+
294+ pathsep = _SPLIT_ON_PATH [self .get_main_option ("version_path_separator" )]
295+ self .set_main_option (
296+ "version_locations" , pathsep .join (map (str , version_locations ))
297+ )
298+
299+
300+ def use_tempdir (
301+ func : Callable [Concatenate [AlembicConfig , _P ], _T ]
302+ ) -> Callable [Concatenate [AlembicConfig , _P ], _T ]:
303+ def wrapper (config : AlembicConfig , * args : _P .args , ** kwargs : _P .kwargs ) -> _T :
304+ if config ._temp_dir :
305+ return func (config , * args , ** kwargs )
306+ raise RuntimeError ("AlembicConfig 未启用临时目录" )
307+
308+ return wrapper
309+
259310
260311def list_templates (config : AlembicConfig ) -> None :
261312 """列出所有可用的模板。
@@ -310,6 +361,7 @@ def init(
310361 )
311362
312363
364+ @use_tempdir
313365def revision (
314366 config : AlembicConfig ,
315367 message : str | None = None ,
@@ -351,7 +403,9 @@ def revision(
351403 and plugin .metadata
352404 and plugin .metadata .extra .get ("orm_version_location" )
353405 ):
354- version_path = Path (config ._tempdir .name ) / branch_label
406+ version_path = (
407+ Path (cast (TemporaryDirectory , config ._temp_dir ).name ) / branch_label
408+ )
355409
356410 script_directory = ScriptDirectory .from_config (config )
357411
@@ -455,6 +509,7 @@ def retrieve_migrations(rev, context):
455509 config .print_stdout ("没有检测到新的升级操作" )
456510
457511
512+ @use_tempdir
458513def merge (
459514 config : AlembicConfig ,
460515 revisions : tuple [str , ...],
@@ -530,12 +585,12 @@ def upgrade(
530585 def upgrade (rev , _ ):
531586 nonlocal fast
532587
533- if fast and revision in ( "head" , "heads" ) and not script .get_all_current (rev ):
588+ if fast and revision in { "head" , "heads" } and not script .get_all_current (rev ):
534589 await_fallback (_upgrade_fast (config ))
535590 return script ._stamp_revs (revision , rev )
536591 else :
537592 fast = False
538- return script ._upgrade_revs (revision , rev )
593+ return script ._upgrade_revs (revision , rev ) # type: ignore
539594
540595 with EnvironmentContext (
541596 config ,
@@ -817,6 +872,7 @@ def do_stamp(rev, _):
817872 script .run_env ()
818873
819874
875+ @use_tempdir
820876def edit (config : AlembicConfig , rev : str = "current" ) -> None :
821877 """使用 `$EDITOR` 编辑修订文件。
822878
@@ -826,9 +882,6 @@ def edit(config: AlembicConfig, rev: str = "current") -> None:
826882 """
827883
828884 script = ScriptDirectory .from_config (config )
829- temp_version_locations = click .get_current_context ().meta .get (
830- f"{ __name__ } .temp_version_locations"
831- )
832885
833886 if rev == "current" :
834887
0 commit comments