Skip to content

Commit 17d1be4

Browse files
committed
Add naming constraint as cli args
1 parent 53de15d commit 17d1be4

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/sqlacodegen/cli.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import argparse
44
import sys
55
from contextlib import ExitStack
6-
from typing import TextIO
6+
from typing import Sequence, TextIO
77

88
from sqlalchemy.engine import create_engine
99
from sqlalchemy.schema import MetaData
@@ -14,6 +14,18 @@
1414
from importlib.metadata import entry_points, version
1515

1616

17+
def parse_naming_convs(naming_convs: Sequence[str]) -> dict[str, str]:
18+
d = {}
19+
for naming_conv in naming_convs:
20+
try:
21+
key, value = naming_conv.split("=", 1)
22+
except ValueError:
23+
raise ValueError('Naming convention must be in "key=template" format')
24+
25+
d[key] = value
26+
return d
27+
28+
1729
def main() -> None:
1830
generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")}
1931
parser = argparse.ArgumentParser(
@@ -40,6 +52,13 @@ def main() -> None:
4052
)
4153
parser.add_argument("--noviews", action="store_true", help="ignore views")
4254
parser.add_argument("--outfile", help="file to write output to (default: stdout)")
55+
parser.add_argument(
56+
"--conv",
57+
nargs="*",
58+
help='constraint naming conventions in "key=template" format \
59+
e.g., --conv "pk=pk_%%(table_name)s" "uq=uq_%%(table_name)s_%%(column_0_name)s"',
60+
)
61+
4362
args = parser.parse_args()
4463

4564
if args.version:
@@ -58,6 +77,11 @@ def main() -> None:
5877
for schema in schemas:
5978
metadata.reflect(engine, schema, not args.noviews, tables)
6079

80+
# Naming convention must be added after reflection to
81+
# avoid the token %(constraint_name)s duplicating the name
82+
if args.conv:
83+
metadata.naming_convention = parse_naming_convs(args.conv)
84+
6185
# Instantiate the generator
6286
generator_class = generators[args.generator].load()
6387
generator = generator_class(metadata, engine, set(args.option or ()))

tests/test_cli.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,36 @@ def test_main() -> None:
172172
check=True,
173173
)
174174
assert completed.stdout.decode().strip() == expected_version
175+
176+
177+
@pytest.fixture
178+
def empty_db_path(tmp_path: Path) -> Path:
179+
path = tmp_path / "test.db"
180+
181+
return path
182+
183+
184+
def test_naming_convention(empty_db_path: Path, tmp_path: Path) -> None:
185+
output_path = tmp_path / "outfile"
186+
subprocess.run(
187+
[
188+
"sqlacodegen",
189+
f"sqlite:///{empty_db_path}",
190+
"--outfile",
191+
str(output_path),
192+
"--conv",
193+
"pk=pk_%(table_name)s",
194+
],
195+
check=True,
196+
)
197+
198+
assert (
199+
output_path.read_text()
200+
== """\
201+
from sqlalchemy import MetaData
202+
203+
metadata = MetaData()
204+
metadata.naming_convention = {'pk': 'pk_%(table_name)s'}
205+
206+
"""
207+
)

0 commit comments

Comments
 (0)