11from __future__ import annotations
22
33from pathlib import Path
4- from typing import TYPE_CHECKING , Sequence , cast
4+ from typing import TYPE_CHECKING , Sequence , Union , cast
55
66if TYPE_CHECKING :
77 from click import Group
@@ -43,7 +43,7 @@ def alchemy_group(ctx: click.Context, config: str) -> None:
4343 ctx .ensure_object (dict )
4444 try :
4545 config_instance = module_loader .import_string (config )
46- if isinstance (config_instance , ( list , tuple ) ):
46+ if isinstance (config_instance , Sequence ):
4747 ctx .obj ["configs" ] = config_instance
4848 else :
4949 ctx .obj ["configs" ] = [config_instance ]
@@ -72,138 +72,163 @@ def add_migration_commands(database_group: Group | None = None) -> Group: # noq
7272 if database_group is None :
7373 database_group = get_alchemy_group ()
7474
75+ bind_key_option = click .option (
76+ "--bind-key" ,
77+ help = "Specify which SQLAlchemy config to use by bind key" ,
78+ type = str ,
79+ default = None ,
80+ )
81+ verbose_option = click .option (
82+ "--verbose" ,
83+ help = "Enable verbose output." ,
84+ type = bool ,
85+ default = False ,
86+ is_flag = True ,
87+ )
88+ no_prompt_option = click .option (
89+ "--no-prompt" ,
90+ help = "Do not prompt for confirmation before executing the command." ,
91+ type = bool ,
92+ default = False ,
93+ required = False ,
94+ show_default = True ,
95+ is_flag = True ,
96+ )
97+
98+ def get_config_by_bind_key (
99+ ctx : click .Context , bind_key : str | None
100+ ) -> SQLAlchemyAsyncConfig | SQLAlchemySyncConfig :
101+ """Get the SQLAlchemy config for the specified bind key."""
102+ configs = ctx .obj ["configs" ]
103+ if bind_key is None :
104+ return cast ("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]" , configs [0 ])
105+
106+ for config in configs :
107+ if config .bind_key == bind_key :
108+ return cast ("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]" , config )
109+
110+ console .print (f"[red]No config found for bind key: { bind_key } [/]" )
111+ ctx .exit (1 ) # noqa: RET503
112+
75113 @database_group .command (
76114 name = "show-current-revision" ,
77115 help = "Shows the current revision for the database." ,
78116 )
79- @click . option ( "--verbose" , type = bool , help = "Enable verbose output." , default = False , is_flag = True )
80- @click . pass_context
81- def show_database_revision (ctx : click . Context , verbose : bool ) -> None : # pyright: ignore[reportUnusedFunction]
117+ @bind_key_option
118+ @verbose_option
119+ def show_database_revision (bind_key : str | None , verbose : bool ) -> None : # pyright: ignore[reportUnusedFunction]
82120 """Show current database revision."""
83121 from advanced_alchemy .alembic .commands import AlembicCommands
84122
123+ ctx = click .get_current_context ()
85124 console .rule ("[yellow]Listing current revision[/]" , align = "left" )
86- sqlalchemy_config = ctx . obj [ "configs" ][ 0 ]
125+ sqlalchemy_config = get_config_by_bind_key ( ctx , bind_key )
87126 alembic_commands = AlembicCommands (sqlalchemy_config = sqlalchemy_config )
88127 alembic_commands .current (verbose = verbose )
89128
90129 @database_group .command (
91130 name = "downgrade" ,
92131 help = "Downgrade database to a specific revision." ,
93132 )
133+ @bind_key_option
94134 @click .option ("--sql" , type = bool , help = "Generate SQL output for offline migrations." , default = False , is_flag = True )
95135 @click .option (
96136 "--tag" ,
97137 help = "an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method." ,
98138 type = str ,
99139 default = None ,
100140 )
101- @click .option (
102- "--no-prompt" ,
103- help = "Do not prompt for confirmation before downgrading." ,
104- type = bool ,
105- default = False ,
106- required = False ,
107- show_default = True ,
108- is_flag = True ,
109- )
141+ @no_prompt_option
110142 @click .argument (
111143 "revision" ,
112144 type = str ,
113145 default = "-1" ,
114146 )
115- @click .pass_context
116- def downgrade_database (ctx : click .Context , revision : str , sql : bool , tag : str | None , no_prompt : bool ) -> None : # pyright: ignore[reportUnusedFunction]
147+ def downgrade_database ( # pyright: ignore[reportUnusedFunction]
148+ bind_key : str | None , revision : str , sql : bool , tag : str | None , no_prompt : bool
149+ ) -> None :
117150 """Downgrade the database to the latest revision."""
118151 from rich .prompt import Confirm
119152
120153 from advanced_alchemy .alembic .commands import AlembicCommands
121154
155+ ctx = click .get_current_context ()
122156 console .rule ("[yellow]Starting database downgrade process[/]" , align = "left" )
123157 input_confirmed = (
124158 True
125159 if no_prompt
126160 else Confirm .ask (f"Are you sure you want to downgrade the database to the `{ revision } ` revision?" )
127161 )
128162 if input_confirmed :
129- sqlalchemy_config = ctx . obj [ "configs" ][ 0 ]
163+ sqlalchemy_config = get_config_by_bind_key ( ctx , bind_key )
130164 alembic_commands = AlembicCommands (sqlalchemy_config = sqlalchemy_config )
131165 alembic_commands .downgrade (revision = revision , sql = sql , tag = tag )
132166
133167 @database_group .command (
134168 name = "upgrade" ,
135169 help = "Upgrade database to a specific revision." ,
136170 )
171+ @bind_key_option
137172 @click .option ("--sql" , type = bool , help = "Generate SQL output for offline migrations." , default = False , is_flag = True )
138173 @click .option (
139174 "--tag" ,
140175 help = "an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method." ,
141176 type = str ,
142177 default = None ,
143178 )
144- @click .option (
145- "--no-prompt" ,
146- help = "Do not prompt for confirmation before upgrading." ,
147- type = bool ,
148- default = False ,
149- required = False ,
150- show_default = True ,
151- is_flag = True ,
152- )
179+ @no_prompt_option
153180 @click .argument (
154181 "revision" ,
155182 type = str ,
156183 default = "head" ,
157184 )
158- @click .pass_context
159- def upgrade_database (ctx : click .Context , revision : str , sql : bool , tag : str | None , no_prompt : bool ) -> None : # pyright: ignore[reportUnusedFunction]
185+ def upgrade_database ( # pyright: ignore[reportUnusedFunction]
186+ bind_key : str | None , revision : str , sql : bool , tag : str | None , no_prompt : bool
187+ ) -> None :
160188 """Upgrade the database to the latest revision."""
161189 from rich .prompt import Confirm
162190
163191 from advanced_alchemy .alembic .commands import AlembicCommands
164192
193+ ctx = click .get_current_context ()
165194 console .rule ("[yellow]Starting database upgrade process[/]" , align = "left" )
166195 input_confirmed = (
167196 True
168197 if no_prompt
169198 else Confirm .ask (f"[bold]Are you sure you want migrate the database to the `{ revision } ` revision?[/]" )
170199 )
171200 if input_confirmed :
172- sqlalchemy_config = ctx . obj [ "configs" ][ 0 ]
201+ sqlalchemy_config = get_config_by_bind_key ( ctx , bind_key )
173202 alembic_commands = AlembicCommands (sqlalchemy_config = sqlalchemy_config )
174203 alembic_commands .upgrade (revision = revision , sql = sql , tag = tag )
175204
176205 @database_group .command (
177206 name = "init" ,
178207 help = "Initialize migrations for the project." ,
179208 )
209+ @bind_key_option
180210 @click .argument ("directory" , default = None )
181211 @click .option ("--multidb" , is_flag = True , default = False , help = "Support multiple databases" )
182212 @click .option ("--package" , is_flag = True , default = True , help = "Create `__init__.py` for created folder" )
183- @click .option (
184- "--no-prompt" ,
185- help = "Do not prompt for confirmation before initializing." ,
186- type = bool ,
187- default = False ,
188- required = False ,
189- show_default = True ,
190- is_flag = True ,
191- )
192- @click .pass_context
193- def init_alembic (ctx : click .Context , directory : str | None , multidb : bool , package : bool , no_prompt : bool ) -> None : # pyright: ignore[reportUnusedFunction]
213+ @no_prompt_option
214+ def init_alembic ( # pyright: ignore[reportUnusedFunction]
215+ bind_key : str | None , directory : str | None , multidb : bool , package : bool , no_prompt : bool
216+ ) -> None :
194217 """Initialize the database migrations."""
195218 from rich .prompt import Confirm
196219
197220 from advanced_alchemy .alembic .commands import AlembicCommands
198221
222+ ctx = click .get_current_context ()
199223 console .rule ("[yellow]Initializing database migrations." , align = "left" )
200224 input_confirmed = (
201225 True
202226 if no_prompt
203227 else Confirm .ask (f"[bold]Are you sure you want initialize the project in `{ directory } `?[/]" )
204228 )
205229 if input_confirmed :
206- for config in ctx .obj ["configs" ]:
230+ configs = [get_config_by_bind_key (ctx , bind_key )] if bind_key is not None else ctx .obj ["configs" ]
231+ for config in configs :
207232 directory = config .alembic_config .script_location if directory is None else directory
208233 alembic_commands = AlembicCommands (sqlalchemy_config = config )
209234 alembic_commands .init (directory = cast ("str" , directory ), multidb = multidb , package = package )
@@ -212,6 +237,7 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
212237 name = "make-migrations" ,
213238 help = "Create a new migration revision." ,
214239 )
240+ @bind_key_option
215241 @click .option ("-m" , "--message" , default = None , help = "Revision message" )
216242 @click .option (
217243 "--autogenerate/--no-autogenerate" , default = True , help = "Automatically populate revision with detected changes"
@@ -224,18 +250,9 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
224250 @click .option ("--branch-label" , default = None , help = "Specify a branch label to apply to the new revision" )
225251 @click .option ("--version-path" , default = None , help = "Specify specific path from config for version file" )
226252 @click .option ("--rev-id" , default = None , help = "Specify a ID to use for revision." )
227- @click .option (
228- "--no-prompt" ,
229- help = "Do not prompt for a migration message." ,
230- type = bool ,
231- default = False ,
232- required = False ,
233- show_default = True ,
234- is_flag = True ,
235- )
236- @click .pass_context
253+ @no_prompt_option
237254 def create_revision ( # pyright: ignore[reportUnusedFunction]
238- ctx : click . Context ,
255+ bind_key : str | None ,
239256 message : str | None ,
240257 autogenerate : bool ,
241258 sql : bool ,
@@ -275,11 +292,12 @@ def process_revision_directives(
275292 )
276293 directives .clear ()
277294
295+ ctx = click .get_current_context ()
278296 console .rule ("[yellow]Starting database upgrade process[/]" , align = "left" )
279297 if message is None :
280298 message = "autogenerated" if no_prompt else Prompt .ask ("Please enter a message describing this revision" )
281299
282- sqlalchemy_config = ctx . obj [ "configs" ][ 0 ]
300+ sqlalchemy_config = get_config_by_bind_key ( ctx , bind_key )
283301 alembic_commands = AlembicCommands (sqlalchemy_config = sqlalchemy_config )
284302 alembic_commands .revision (
285303 message = message ,
@@ -294,24 +312,17 @@ def process_revision_directives(
294312 )
295313
296314 @database_group .command (name = "drop-all" , help = "Drop all tables from the database." )
297- @click .option (
298- "--no-prompt" ,
299- help = "Do not prompt for confirmation before upgrading." ,
300- type = bool ,
301- default = False ,
302- required = False ,
303- show_default = True ,
304- is_flag = True ,
305- )
306- @click .pass_context
307- def drop_all (ctx : click .Context , no_prompt : bool ) -> None : # pyright: ignore[reportUnusedFunction]
315+ @bind_key_option
316+ @no_prompt_option
317+ def drop_all (bind_key : str | None , no_prompt : bool ) -> None : # pyright: ignore[reportUnusedFunction]
308318 """Drop all tables from the database."""
309319 from anyio import run
310320 from rich .prompt import Confirm
311321
312322 from advanced_alchemy .alembic .utils import drop_all
313323 from advanced_alchemy .base import metadata_registry
314324
325+ ctx = click .get_current_context ()
315326 console .rule ("[yellow]Dropping all tables from the database[/]" , align = "left" )
316327 input_confirmed = no_prompt or Confirm .ask (
317328 "[bold red]Are you sure you want to drop all tables from the database?"
@@ -325,9 +336,11 @@ async def _drop_all(
325336 await drop_all (engine , config .alembic_config .version_table_name , metadata_registry .get (config .bind_key ))
326337
327338 if input_confirmed :
328- run (_drop_all , ctx .obj ["configs" ])
339+ configs = [get_config_by_bind_key (ctx , bind_key )] if bind_key is not None else ctx .obj ["configs" ]
340+ run (_drop_all , configs )
329341
330342 @database_group .command (name = "dump-data" , help = "Dump specified tables from the database to JSON files." )
343+ @bind_key_option
331344 @click .option (
332345 "--table" ,
333346 "table_names" ,
@@ -344,15 +357,15 @@ async def _drop_all(
344357 default = Path .cwd () / "fixtures" ,
345358 required = False ,
346359 )
347- @click .pass_context
348- def dump_table_data (ctx : click .Context , table_names : tuple [str , ...], dump_dir : Path ) -> None : # pyright: ignore[reportUnusedFunction]
360+ def dump_table_data (bind_key : str | None , table_names : tuple [str , ...], dump_dir : Path ) -> None : # pyright: ignore[reportUnusedFunction]
349361 """Dump table data to JSON files."""
350362 from anyio import run
351363 from rich .prompt import Confirm
352364
353365 from advanced_alchemy .alembic .utils import dump_tables
354366 from advanced_alchemy .base import metadata_registry , orm_registry
355367
368+ ctx = click .get_current_context ()
356369 all_tables = "*" in table_names
357370
358371 if all_tables and not Confirm .ask (
@@ -361,7 +374,8 @@ def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir:
361374 return console .rule ("[red bold]No data was dumped." , style = "red" , align = "left" )
362375
363376 async def _dump_tables () -> None :
364- for config in ctx .obj ["configs" ]:
377+ configs = [get_config_by_bind_key (ctx , bind_key )] if bind_key is not None else ctx .obj ["configs" ]
378+ for config in configs :
365379 target_tables = set (metadata_registry .get (config .bind_key ).tables )
366380
367381 if not all_tables :
0 commit comments