1
+ import inspect # Added import
1
2
import sys
2
3
from typing import TYPE_CHECKING , Any , Optional , TextIO , Union
3
4
4
5
from advanced_alchemy .config .asyncio import SQLAlchemyAsyncConfig
6
+ from advanced_alchemy .exceptions import ImproperConfigurationError
5
7
from alembic import command as migration_command
6
8
from alembic .config import Config as _AlembicCommandConfig
7
9
from alembic .ddl .impl import DefaultImpl
@@ -39,6 +41,7 @@ def __init__(
39
41
version_table_name : str ,
40
42
bind_key : "Optional[str]" = None ,
41
43
file_ : "Union[str, os.PathLike[str], None]" = None ,
44
+ toml_file : "Union[str, os.PathLike[str], None]" = None ,
42
45
ini_section : str = "alembic" ,
43
46
output_buffer : "Optional[TextIO]" = None ,
44
47
stdout : "TextIO" = sys .stdout ,
@@ -57,7 +60,8 @@ def __init__(
57
60
engine (sqlalchemy.engine.Engine | sqlalchemy.ext.asyncio.AsyncEngine): The SQLAlchemy engine instance.
58
61
version_table_name (str): The name of the version table.
59
62
bind_key (str | None): The bind key for the metadata.
60
- file_ (str | os.PathLike[str] | None): The file path for the alembic configuration.
63
+ file_ (str | os.PathLike[str] | None): The file path for the alembic .ini configuration.
64
+ toml_file (str | os.PathLike[str] | None): The file path for the alembic pyproject.toml configuration.
61
65
ini_section (str): The ini section name.
62
66
output_buffer (typing.TextIO | None): The output buffer for alembic commands.
63
67
stdout (typing.TextIO): The standard output stream.
@@ -80,9 +84,33 @@ def __init__(
80
84
self .compare_type = compare_type
81
85
self .engine = engine
82
86
self .db_url = engine .url .render_as_string (hide_password = False )
83
- if config_args is None :
84
- config_args = {}
85
- super ().__init__ (file_ , ini_section , output_buffer , stdout , cmd_opts , config_args , attributes )
87
+
88
+ _config_args = {} if config_args is None else dict (config_args )
89
+
90
+ # Prepare kwargs for super().__init__
91
+ super_init_kwargs : dict [str , Any ] = {
92
+ "file_" : file_ ,
93
+ "ini_section" : ini_section ,
94
+ "output_buffer" : output_buffer ,
95
+ "stdout" : stdout ,
96
+ "cmd_opts" : cmd_opts ,
97
+ "config_args" : _config_args , # Pass the mutable copy
98
+ "attributes" : attributes ,
99
+ }
100
+
101
+ # Inspect the parent class __init__ for toml_file parameter
102
+ parent_init_sig = inspect .signature (super ().__init__ )
103
+ if "toml_file" in parent_init_sig .parameters :
104
+ super_init_kwargs ["toml_file" ] = toml_file
105
+ elif toml_file is not None :
106
+ msg = (
107
+ "The 'toml_file' parameter is not supported by your current Alembic version. "
108
+ "Please upgrade Alembic to 1.16.0 or later to use this feature, "
109
+ "or remove the 'toml_file' argument from AlembicCommandConfig."
110
+ )
111
+ raise ImproperConfigurationError (msg )
112
+
113
+ super ().__init__ (** super_init_kwargs )
86
114
87
115
def get_template_directory (self ) -> str :
88
116
"""Return the directory where Alembic setup templates are found.
@@ -337,6 +365,8 @@ def _get_alembic_command_config(self) -> "AlembicCommandConfig":
337
365
AlembicCommandConfig: The configuration for Alembic commands.
338
366
"""
339
367
kwargs : dict [str , Any ] = {}
368
+ if self .sqlalchemy_config .alembic_config .toml_file :
369
+ kwargs ["toml_file" ] = self .sqlalchemy_config .alembic_config .toml_file
340
370
if self .sqlalchemy_config .alembic_config .script_config :
341
371
kwargs ["file_" ] = self .sqlalchemy_config .alembic_config .script_config
342
372
if self .sqlalchemy_config .alembic_config .template_path :
0 commit comments