88from pprint import pformat
99from argparse import Namespace
1010from operator import attrgetter
11+ from itertools import filterfalse
1112from typing import Any , TextIO , cast
1213from tempfile import TemporaryDirectory
1314from configparser import DuplicateSectionError
@@ -180,11 +181,12 @@ def move_script(self, script: Script) -> Path:
180181 return script_path
181182
182183 plugin_name = script_path .parent .name
183- if version_location := self ._plugin_version_locations .get (plugin_name ):
184- pass
185- elif version_location := self ._plugin_version_locations .get ("" ):
186- plugin_name = ""
187- else :
184+ version_location = self ._plugin_version_locations .get (plugin_name )
185+
186+ if not version_location :
187+ version_location = self ._plugin_version_locations .get ("" )
188+
189+ if not version_location :
188190 self .print_stdout (
189191 f'无法找到 { plugin_name or "<default>" } 对应的版本目录, 忽略 "{ script .path } "' ,
190192 fg = "yellow" ,
@@ -406,7 +408,7 @@ def revision(
406408 config: `AlembicConfig` 对象
407409 message: 迁移的描述
408410 sql: 是否以 SQL 的形式输出迁移脚本
409- head: 迁移的基准版本, 提供了 branch_label 时默认为 'base', 否则默认为 'head'
411+ head: 迁移的基准版本, 如果提供了 branch_label 默认为 `branch_label@head`, 否则为主分支的头
410412 splice: 是否将迁移作为一个新的分支的头; 当 `head` 不是一个分支的头时, 此项必须为 `True`
411413 branch_label: 迁移的分支标签
412414 version_path: 存放迁移脚本的目录
@@ -416,24 +418,12 @@ def revision(
416418 """
417419 from . import _plugins
418420
419- if head is None :
420- head = "base" if branch_label else "head"
421-
422- if not version_path and branch_label and (plugin := _plugins .get (branch_label )):
423- version_path = str (
424- config ._temp_dir .joinpath (
425- * map (
426- attrgetter ("name" ),
427- reversed (list (get_parent_plugins (plugin ))),
428- )
429- )
430- )
431- elif version_path :
421+ if version_path :
432422 version_path = Path (version_path ).resolve ()
433423 version_locations = config .get_main_option ("version_locations" , "" )
434424 pathsep = _SPLIT_ON_PATH [config .get_main_option ("version_path_separator" )]
435425
436- if version_path in (
426+ if version_path not in (
437427 Path (path ).resolve () for path in version_locations .split (pathsep )
438428 ):
439429 config .set_main_option (
@@ -442,9 +432,34 @@ def revision(
442432 logger .warning (
443433 f'临时将目录 "{ version_path } " 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
444434 )
435+ elif branch_label and (plugin := _plugins .get (branch_label )):
436+ version_path = config ._temp_dir .joinpath (
437+ * map (
438+ attrgetter ("name" ),
439+ reversed (list (get_parent_plugins (plugin ))),
440+ )
441+ )
442+ else :
443+ version_path = config ._temp_dir
445444
446445 script = ScriptDirectory .from_config (config )
447446
447+ if not head :
448+ if branch_label :
449+ head = f"{ branch_label } @head"
450+ elif len (heads := script .get_heads ()) <= 1 :
451+ head = "head"
452+ else :
453+ try :
454+ head = next (
455+ filterfalse (
456+ attrgetter ("branch_labels" ),
457+ script .get_revisions (heads ),
458+ )
459+ ).revision
460+ except StopIteration :
461+ head = "base"
462+
448463 revision_context = RevisionContext (
449464 config ,
450465 script ,
@@ -455,7 +470,7 @@ def revision(
455470 head = head ,
456471 splice = splice ,
457472 branch_label = branch_label ,
458- version_path = version_path ,
473+ version_path = str ( version_path ) ,
459474 rev_id = rev_id ,
460475 depends_on = depends_on ,
461476 ),
0 commit comments