Skip to content

Commit ed9f5cd

Browse files
authored
feat(cli): add bind-key option to CLI (#339)
Adds a new `bind-key` option to the CLI for specifying which engine configuration to use for migrations.
1 parent 2caeeda commit ed9f5cd

File tree

5 files changed

+101
-89
lines changed

5 files changed

+101
-89
lines changed

advanced_alchemy/cli.py

Lines changed: 83 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from pathlib import Path
4-
from typing import TYPE_CHECKING, Sequence, cast
4+
from typing import TYPE_CHECKING, Sequence, Union, cast
55

66
if TYPE_CHECKING:
77
from click import Group
@@ -43,7 +43,7 @@ def alchemy_group(ctx: click.Context, config: str) -> None:
4343
ctx.ensure_object(dict)
4444
try:
4545
config_instance = module_loader.import_string(config)
46-
if isinstance(config_instance, (list, tuple)):
46+
if isinstance(config_instance, Sequence):
4747
ctx.obj["configs"] = config_instance
4848
else:
4949
ctx.obj["configs"] = [config_instance]
@@ -72,138 +72,163 @@ def add_migration_commands(database_group: Group | None = None) -> Group: # noq
7272
if database_group is None:
7373
database_group = get_alchemy_group()
7474

75+
bind_key_option = click.option(
76+
"--bind-key",
77+
help="Specify which SQLAlchemy config to use by bind key",
78+
type=str,
79+
default=None,
80+
)
81+
verbose_option = click.option(
82+
"--verbose",
83+
help="Enable verbose output.",
84+
type=bool,
85+
default=False,
86+
is_flag=True,
87+
)
88+
no_prompt_option = click.option(
89+
"--no-prompt",
90+
help="Do not prompt for confirmation before executing the command.",
91+
type=bool,
92+
default=False,
93+
required=False,
94+
show_default=True,
95+
is_flag=True,
96+
)
97+
98+
def get_config_by_bind_key(
99+
ctx: click.Context, bind_key: str | None
100+
) -> SQLAlchemyAsyncConfig | SQLAlchemySyncConfig:
101+
"""Get the SQLAlchemy config for the specified bind key."""
102+
configs = ctx.obj["configs"]
103+
if bind_key is None:
104+
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0])
105+
106+
for config in configs:
107+
if config.bind_key == bind_key:
108+
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config)
109+
110+
console.print(f"[red]No config found for bind key: {bind_key}[/]")
111+
ctx.exit(1) # noqa: RET503
112+
75113
@database_group.command(
76114
name="show-current-revision",
77115
help="Shows the current revision for the database.",
78116
)
79-
@click.option("--verbose", type=bool, help="Enable verbose output.", default=False, is_flag=True)
80-
@click.pass_context
81-
def show_database_revision(ctx: click.Context, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
117+
@bind_key_option
118+
@verbose_option
119+
def show_database_revision(bind_key: str | None, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
82120
"""Show current database revision."""
83121
from advanced_alchemy.alembic.commands import AlembicCommands
84122

123+
ctx = click.get_current_context()
85124
console.rule("[yellow]Listing current revision[/]", align="left")
86-
sqlalchemy_config = ctx.obj["configs"][0]
125+
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
87126
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
88127
alembic_commands.current(verbose=verbose)
89128

90129
@database_group.command(
91130
name="downgrade",
92131
help="Downgrade database to a specific revision.",
93132
)
133+
@bind_key_option
94134
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
95135
@click.option(
96136
"--tag",
97137
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
98138
type=str,
99139
default=None,
100140
)
101-
@click.option(
102-
"--no-prompt",
103-
help="Do not prompt for confirmation before downgrading.",
104-
type=bool,
105-
default=False,
106-
required=False,
107-
show_default=True,
108-
is_flag=True,
109-
)
141+
@no_prompt_option
110142
@click.argument(
111143
"revision",
112144
type=str,
113145
default="-1",
114146
)
115-
@click.pass_context
116-
def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
147+
def downgrade_database( # pyright: ignore[reportUnusedFunction]
148+
bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool
149+
) -> None:
117150
"""Downgrade the database to the latest revision."""
118151
from rich.prompt import Confirm
119152

120153
from advanced_alchemy.alembic.commands import AlembicCommands
121154

155+
ctx = click.get_current_context()
122156
console.rule("[yellow]Starting database downgrade process[/]", align="left")
123157
input_confirmed = (
124158
True
125159
if no_prompt
126160
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
127161
)
128162
if input_confirmed:
129-
sqlalchemy_config = ctx.obj["configs"][0]
163+
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
130164
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
131165
alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)
132166

133167
@database_group.command(
134168
name="upgrade",
135169
help="Upgrade database to a specific revision.",
136170
)
171+
@bind_key_option
137172
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
138173
@click.option(
139174
"--tag",
140175
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
141176
type=str,
142177
default=None,
143178
)
144-
@click.option(
145-
"--no-prompt",
146-
help="Do not prompt for confirmation before upgrading.",
147-
type=bool,
148-
default=False,
149-
required=False,
150-
show_default=True,
151-
is_flag=True,
152-
)
179+
@no_prompt_option
153180
@click.argument(
154181
"revision",
155182
type=str,
156183
default="head",
157184
)
158-
@click.pass_context
159-
def upgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
185+
def upgrade_database( # pyright: ignore[reportUnusedFunction]
186+
bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool
187+
) -> None:
160188
"""Upgrade the database to the latest revision."""
161189
from rich.prompt import Confirm
162190

163191
from advanced_alchemy.alembic.commands import AlembicCommands
164192

193+
ctx = click.get_current_context()
165194
console.rule("[yellow]Starting database upgrade process[/]", align="left")
166195
input_confirmed = (
167196
True
168197
if no_prompt
169198
else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]")
170199
)
171200
if input_confirmed:
172-
sqlalchemy_config = ctx.obj["configs"][0]
201+
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
173202
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
174203
alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)
175204

176205
@database_group.command(
177206
name="init",
178207
help="Initialize migrations for the project.",
179208
)
209+
@bind_key_option
180210
@click.argument("directory", default=None)
181211
@click.option("--multidb", is_flag=True, default=False, help="Support multiple databases")
182212
@click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
183-
@click.option(
184-
"--no-prompt",
185-
help="Do not prompt for confirmation before initializing.",
186-
type=bool,
187-
default=False,
188-
required=False,
189-
show_default=True,
190-
is_flag=True,
191-
)
192-
@click.pass_context
193-
def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, package: bool, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
213+
@no_prompt_option
214+
def init_alembic( # pyright: ignore[reportUnusedFunction]
215+
bind_key: str | None, directory: str | None, multidb: bool, package: bool, no_prompt: bool
216+
) -> None:
194217
"""Initialize the database migrations."""
195218
from rich.prompt import Confirm
196219

197220
from advanced_alchemy.alembic.commands import AlembicCommands
198221

222+
ctx = click.get_current_context()
199223
console.rule("[yellow]Initializing database migrations.", align="left")
200224
input_confirmed = (
201225
True
202226
if no_prompt
203227
else Confirm.ask(f"[bold]Are you sure you want initialize the project in `{directory}`?[/]")
204228
)
205229
if input_confirmed:
206-
for config in ctx.obj["configs"]:
230+
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
231+
for config in configs:
207232
directory = config.alembic_config.script_location if directory is None else directory
208233
alembic_commands = AlembicCommands(sqlalchemy_config=config)
209234
alembic_commands.init(directory=cast("str", directory), multidb=multidb, package=package)
@@ -212,6 +237,7 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
212237
name="make-migrations",
213238
help="Create a new migration revision.",
214239
)
240+
@bind_key_option
215241
@click.option("-m", "--message", default=None, help="Revision message")
216242
@click.option(
217243
"--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes"
@@ -224,18 +250,9 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
224250
@click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
225251
@click.option("--version-path", default=None, help="Specify specific path from config for version file")
226252
@click.option("--rev-id", default=None, help="Specify a ID to use for revision.")
227-
@click.option(
228-
"--no-prompt",
229-
help="Do not prompt for a migration message.",
230-
type=bool,
231-
default=False,
232-
required=False,
233-
show_default=True,
234-
is_flag=True,
235-
)
236-
@click.pass_context
253+
@no_prompt_option
237254
def create_revision( # pyright: ignore[reportUnusedFunction]
238-
ctx: click.Context,
255+
bind_key: str | None,
239256
message: str | None,
240257
autogenerate: bool,
241258
sql: bool,
@@ -275,11 +292,12 @@ def process_revision_directives(
275292
)
276293
directives.clear()
277294

295+
ctx = click.get_current_context()
278296
console.rule("[yellow]Starting database upgrade process[/]", align="left")
279297
if message is None:
280298
message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")
281299

282-
sqlalchemy_config = ctx.obj["configs"][0]
300+
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
283301
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
284302
alembic_commands.revision(
285303
message=message,
@@ -294,24 +312,17 @@ def process_revision_directives(
294312
)
295313

296314
@database_group.command(name="drop-all", help="Drop all tables from the database.")
297-
@click.option(
298-
"--no-prompt",
299-
help="Do not prompt for confirmation before upgrading.",
300-
type=bool,
301-
default=False,
302-
required=False,
303-
show_default=True,
304-
is_flag=True,
305-
)
306-
@click.pass_context
307-
def drop_all(ctx: click.Context, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
315+
@bind_key_option
316+
@no_prompt_option
317+
def drop_all(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
308318
"""Drop all tables from the database."""
309319
from anyio import run
310320
from rich.prompt import Confirm
311321

312322
from advanced_alchemy.alembic.utils import drop_all
313323
from advanced_alchemy.base import metadata_registry
314324

325+
ctx = click.get_current_context()
315326
console.rule("[yellow]Dropping all tables from the database[/]", align="left")
316327
input_confirmed = no_prompt or Confirm.ask(
317328
"[bold red]Are you sure you want to drop all tables from the database?"
@@ -325,9 +336,11 @@ async def _drop_all(
325336
await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))
326337

327338
if input_confirmed:
328-
run(_drop_all, ctx.obj["configs"])
339+
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
340+
run(_drop_all, configs)
329341

330342
@database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
343+
@bind_key_option
331344
@click.option(
332345
"--table",
333346
"table_names",
@@ -344,15 +357,15 @@ async def _drop_all(
344357
default=Path.cwd() / "fixtures",
345358
required=False,
346359
)
347-
@click.pass_context
348-
def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
360+
def dump_table_data(bind_key: str | None, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
349361
"""Dump table data to JSON files."""
350362
from anyio import run
351363
from rich.prompt import Confirm
352364

353365
from advanced_alchemy.alembic.utils import dump_tables
354366
from advanced_alchemy.base import metadata_registry, orm_registry
355367

368+
ctx = click.get_current_context()
356369
all_tables = "*" in table_names
357370

358371
if all_tables and not Confirm.ask(
@@ -361,7 +374,8 @@ def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir:
361374
return console.rule("[red bold]No data was dumped.", style="red", align="left")
362375

363376
async def _dump_tables() -> None:
364-
for config in ctx.obj["configs"]:
377+
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
378+
for config in configs:
365379
target_tables = set(metadata_registry.get(config.bind_key).tables)
366380

367381
if not all_tables:

advanced_alchemy/extensions/litestar/cli.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from contextlib import suppress
44
from typing import TYPE_CHECKING
55

6+
from litestar.cli._utils import LitestarGroup
7+
8+
from advanced_alchemy.cli import add_migration_commands
9+
610
try:
711
import rich_click as click
812
except ImportError:
913
import click # type: ignore[no-redef]
10-
from litestar.cli._utils import LitestarGroup
11-
12-
from advanced_alchemy.cli import add_migration_commands
1314

1415
if TYPE_CHECKING:
1516
from litestar import Litestar
@@ -18,11 +19,7 @@
1819

1920

2021
def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin:
21-
"""Retrieve a database migration plugin from the Litestar application's plugins.
22-
23-
This function attempts to find and return either the SQLAlchemyPlugin or SQLAlchemyInitPlugin.
24-
If neither plugin is found, it raises an ImproperlyConfiguredException.
25-
"""
22+
"""Retrieve a database migration plugin from the Litestar application's plugins."""
2623
from advanced_alchemy.exceptions import ImproperConfigurationError
2724
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin
2825

@@ -33,10 +30,9 @@ def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin:
3330

3431

3532
@click.group(cls=LitestarGroup, name="database")
36-
@click.pass_context
3733
def database_group(ctx: click.Context) -> None:
3834
"""Manage SQLAlchemy database components."""
39-
ctx.obj = get_database_migration_plugin(ctx.obj.app).config
35+
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}
4036

4137

4238
add_migration_commands(database_group)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ sqlite = ["aiosqlite>=0.20.0"]
147147
test = [
148148
"pydantic-extra-types < 2.9.0; python_version < \"3.9\"",
149149
"pydantic-extra-types; python_version >= \"3.9\"",
150+
"rich-click",
150151
"coverage>=7.6.1",
151152
"pytest>=7.4.4",
152153
"pytest-asyncio>=0.23.8",
@@ -468,6 +469,7 @@ exclude = [
468469
include = ["advanced_alchemy"]
469470
pythonVersion = "3.8"
470471
reportUnnecessaryTypeIgnoreComments = true
472+
reportUnusedFunction = false
471473
strict = ["advanced_alchemy/**/*"]
472474
venv = ".venv"
473475
venvPath = "."

0 commit comments

Comments
 (0)