77from pathlib import Path
88from pprint import pformat
99from argparse import Namespace
10+ from operator import attrgetter
1011from typing import Any , TextIO , cast
1112from tempfile import TemporaryDirectory
1213from configparser import DuplicateSectionError
1617import click
1718import alembic
1819import sqlalchemy
19- from nonebot import logger
2020from alembic .config import Config
2121from sqlalchemy .util import asbool
22+ from nonebot import logger , get_plugin
2223from sqlalchemy import MetaData , Connection
2324from alembic .util .editor import open_in_editor
2425from alembic .script import Script , ScriptDirectory
3233 render_python_code ,
3334)
3435
35- from .utils import is_editable , return_progressbar
36+ from .utils import is_editable , get_parent_plugins , return_progressbar
3637
3738if sys .version_info >= (3 , 11 ):
3839 from typing import Self
@@ -91,20 +92,28 @@ def __init__(
9192 ) -> None :
9293 from . import _engines , _metadatas , plugin_config
9394
94- if file_ is None and Path ("alembic.ini" ).is_file ():
95- file_ = "alembic.ini"
95+ self ._exit_stack = ExitStack ()
96+ self ._plugin_version_locations = {}
97+ self ._temp_dir = Path (self ._exit_stack .enter_context (TemporaryDirectory ()))
98+
99+ if file_ is None and isinstance (plugin_config .alembic_config , Path ):
100+ file_ = plugin_config .alembic_config
96101
97102 if plugin_config .alembic_script_location :
98103 script_location = plugin_config .alembic_script_location
99104 elif (
100- Path ("migrations/ env.py" ).is_file ()
101- and Path ("migrations/ script.py.mako" ).is_file ()
105+ Path ("migrations" , " env.py" ).is_file ()
106+ and Path ("migrations" , " script.py.mako" ).is_file ()
102107 ):
103- script_location = "migrations"
108+ script_location = Path ( "migrations" )
104109 elif len (_engines ) == 1 :
105- script_location = str (Path (__file__ ).parent / "templates" / "generic" )
110+ script_location = self ._exit_stack .enter_context (
111+ as_file (files (__name__ ) / "templates" / "generic" )
112+ )
106113 else :
107- script_location = str (Path (__file__ ).parent / "templates" / "multidb" )
114+ script_location = self ._exit_stack .enter_context (
115+ as_file (files (__name__ ) / "templates" / "multidb" )
116+ )
108117
109118 super ().__init__ (
110119 file_ ,
@@ -130,10 +139,6 @@ def __init__(
130139 },
131140 )
132141
133- self ._exit_stack = ExitStack ()
134- self ._plugin_version_locations = {}
135- self ._temp_dir = Path (self ._exit_stack .enter_context (TemporaryDirectory ()))
136-
137142 self ._init_post_write_hooks ()
138143 self ._init_version_locations ()
139144
@@ -173,7 +178,7 @@ def move_script(self, script: Script) -> Path:
173178 except ValueError :
174179 return script_path
175180
176- plugin_name = ( script_path .parent .parts or ( "" ,))[ 0 ]
181+ plugin_name = script_path .parent .name
177182 if version_location := self ._plugin_version_locations .get (plugin_name ):
178183 pass
179184 elif version_location := self ._plugin_version_locations .get ("" ):
@@ -185,12 +190,8 @@ def move_script(self, script: Script) -> Path:
185190 )
186191 return script_path
187192
188- (version_location / script_path .relative_to (plugin_name ).parent ).mkdir (
189- parents = True , exist_ok = True
190- )
191- return shutil .move (
192- script .path , version_location / script_path .relative_to (plugin_name )
193- )
193+ version_location .mkdir (parents = True , exist_ok = True )
194+ return shutil .move (script .path , version_location )
194195
195196 def _add_post_write_hook (self , name : str , ** kwargs : str ) -> None :
196197 self .set_section_option (
@@ -238,16 +239,17 @@ def _init_version_locations(self) -> None:
238239
239240 alembic_version_locations = plugin_config .alembic_version_locations
240241 if isinstance (alembic_version_locations , dict ):
241- main_version_location = Path (
242- alembic_version_locations .get ("" , "migrations/versions" )
243- )
242+ main_version_location = alembic_version_locations .get ("" )
244243 else :
245- main_version_location = Path (
246- alembic_version_locations or "migrations/versions"
247- )
248- self ._plugin_version_locations ["" ] = main_version_location
244+ main_version_location = alembic_version_locations
245+
246+ self ._plugin_version_locations ["" ] = main_version_location or Path (
247+ "migrations" , "versions"
248+ )
249249
250- version_locations = {_data_dir / "migrations" : "" }
250+ temp_version_locations : dict [Path , Path ] = {
251+ _data_dir / "migrations" : self ._temp_dir
252+ }
251253
252254 for plugin in _plugins .values ():
253255 if plugin .metadata and (
@@ -257,39 +259,53 @@ def _init_version_locations(self) -> None:
257259 else :
258260 version_location = files (plugin .module ) / "migrations"
259261
260- if is_editable (plugin ) and isinstance (version_location , Path ):
262+ temp_version_location = Path (
263+ * map (attrgetter ("name" ), reversed (list (get_parent_plugins (plugin )))),
264+ )
265+
266+ if (
267+ not main_version_location
268+ and is_editable (plugin )
269+ and isinstance (version_location , Path )
270+ ):
261271 self ._plugin_version_locations [plugin .name ] = version_location
262272 else :
263273 self ._plugin_version_locations [plugin .name ] = (
264- main_version_location / plugin . name
274+ self . _plugin_version_locations [ "" ] / temp_version_location
265275 )
266276
267- version_locations [
277+ temp_version_locations [
268278 self ._exit_stack .enter_context (as_file (version_location ))
269- ] = plugin . name
279+ ] = ( self . _temp_dir / temp_version_location )
270280
271281 if isinstance (alembic_version_locations , dict ):
272- for name , path in alembic_version_locations .items ():
273- path = self ._plugin_version_locations [name ] = Path (path )
274- version_locations [path ] = name
282+ for plugin_name , version_location in alembic_version_locations .items ():
283+ if not (plugin := get_plugin (plugin_name )):
284+ continue
285+
286+ version_location = Path (version_location )
287+ self ._plugin_version_locations [plugin_name ] = version_location
288+ temp_version_locations [version_location ] = self ._temp_dir .joinpath (
289+ * map (
290+ attrgetter ("name" ),
291+ reversed (list (get_parent_plugins (plugin ))),
292+ )
293+ )
275294
276- version_locations [ main_version_location ] = ""
295+ temp_version_locations [ self . _plugin_version_locations [ "" ]] = self . _temp_dir
277296
278- for src , dst in version_locations .items ():
297+ for src , dst in temp_version_locations .items ():
298+ dst .mkdir (parents = True , exist_ok = True )
279299 with suppress (FileNotFoundError , shutil .Error ):
280- shutil .copytree (src , self . _temp_dir / dst , dirs_exist_ok = True )
300+ shutil .copytree (src , dst , dirs_exist_ok = True )
281301
282302 pathsep = _SPLIT_ON_PATH [self .get_main_option ("version_path_separator" )]
283303 self .set_main_option (
284304 "version_locations" ,
285305 pathsep .join (
286- map (
287- str ,
288- (
289- self ._temp_dir ,
290- * filter (methodcaller ("is_dir" ), self ._temp_dir .iterdir ()),
291- ),
292- )
306+ str (path )
307+ for path in self ._temp_dir .glob ("**" )
308+ if path .name != "__pycache__"
293309 ),
294310 )
295311
@@ -301,11 +317,10 @@ def ignore(path: str, names: list[str]) -> set[str]:
301317 path_ = Path (path )
302318
303319 return set (
304- filter (
305- lambda name : Path (name ).suffix in {".py" , ".pyc" , ".pyo" }
306- and path_ / name not in run_script_path ,
307- names ,
308- )
320+ name
321+ for name in names
322+ if Path (name ).suffix in {".py" , ".pyc" , ".pyo" }
323+ and path_ / name not in run_script_path
309324 )
310325
311326 run_script_path = set (
@@ -377,7 +392,7 @@ def revision(
377392 head : str | None = None ,
378393 splice : bool = False ,
379394 branch_label : str | None = None ,
380- version_path : Path | None = None ,
395+ version_path : str | Path | None = None ,
381396 rev_id : str | None = None ,
382397 depends_on : str | None = None ,
383398 process_revision_directives : ProcessRevisionDirectiveFn | None = None ,
@@ -401,17 +416,30 @@ def revision(
401416 if head is None :
402417 head = "base" if branch_label else "head"
403418
419+ if not version_path and branch_label and (plugin := _plugins .get (branch_label )):
420+ version_path = str (
421+ config ._temp_dir .joinpath (
422+ * map (
423+ attrgetter ("name" ),
424+ reversed (list (get_parent_plugins (plugin ))),
425+ )
426+ )
427+ )
428+
404429 if version_path :
405- version_locations = config .get_main_option ("version_locations" )
430+ version_path = Path (version_path ).resolve ()
431+ version_locations = config .get_main_option ("version_locations" , "" )
406432 pathsep = _SPLIT_ON_PATH [config .get_main_option ("version_path_separator" )]
407- config .set_main_option (
408- "version_locations" , f"{ version_locations } { pathsep } { version_path } "
409- )
410- logger .warning (
411- f'临时将目录 "{ version_path } " 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
412- )
413- elif branch_label in _plugins :
414- version_path = config ._temp_dir / branch_label
433+
434+ if version_path in (
435+ Path (path ).resolve () for path in version_locations .split (pathsep )
436+ ):
437+ config .set_main_option (
438+ "version_locations" , f"{ version_locations } { pathsep } { version_path } "
439+ )
440+ logger .warning (
441+ f'临时将目录 "{ version_path } " 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
442+ )
415443
416444 script = ScriptDirectory .from_config (config )
417445
@@ -740,7 +768,7 @@ def show(config: AlembicConfig, revs: str | Sequence[str] = "current") -> None:
740768 ):
741769 script .run_env ()
742770
743- for sc in cast (Tuple [Script ], script .get_revisions (revs )):
771+ for sc in cast ("tuple [Script]" , script .get_revisions (revs )):
744772 config .print_stdout (sc .log_entry )
745773
746774
@@ -828,7 +856,7 @@ def heads(
828856 else :
829857 heads = script .get_revisions (script .get_heads ())
830858
831- for rev in cast (Tuple [Script ], heads ):
859+ for rev in cast ("tuple [Script]" , heads ):
832860 config .print_stdout (
833861 rev .cmd_format (verbose , include_branches = True , tree_indicators = False )
834862 )
@@ -880,7 +908,7 @@ def display_version(
880908 "Current revision(s) for %s:" ,
881909 cast (Connection , context .connection ).engine .url .render_as_string (),
882910 )
883- for sc in cast (Set [Script ], script .get_all_current (rev )):
911+ for sc in cast ("set [Script]" , script .get_all_current (rev )):
884912 config .print_stdout (sc .cmd_format (verbose ))
885913
886914 return ()
@@ -960,7 +988,7 @@ def edit_current(rev, _) -> Iterable[StampStep | RevisionStep]:
960988 if not rev :
961989 raise click .UsageError ("当前没有迁移" )
962990
963- for sc in cast (Tuple [Script ], script .get_revisions (rev )):
991+ for sc in cast ("tuple [Script]" , script .get_revisions (rev )):
964992 script_path = config .move_script (sc )
965993 open_in_editor (str (script_path ))
966994
@@ -969,12 +997,12 @@ def edit_current(rev, _) -> Iterable[StampStep | RevisionStep]:
969997 with EnvironmentContext (config , script , fn = edit_current ):
970998 script .run_env ()
971999 else :
972- revs = cast (Tuple [Script , ...], script .get_revisions (rev ))
1000+ revs = cast ("tuple [Script, ...]" , script .get_revisions (rev ))
9731001
9741002 if not revs :
9751003 raise click .BadParameter (f'没有 "{ rev } " 指示的迁移脚本' )
9761004
977- for sc in cast (Tuple [Script ], revs ):
1005+ for sc in cast ("tuple [Script]" , revs ):
9781006 script_path = config .move_script (sc )
9791007 open_in_editor (str (script_path ))
9801008
0 commit comments