33import argparse
44import sys
55from contextlib import ExitStack
6- from typing import TextIO
6+ from typing import Sequence , TextIO
77
88from sqlalchemy .engine import create_engine
99from sqlalchemy .schema import MetaData
1414 from importlib .metadata import entry_points , version
1515
1616
17+ def parse_naming_convs (naming_convs : Sequence [str ]) -> dict [str , str ]:
18+ d = {}
19+ for naming_conv in naming_convs :
20+ try :
21+ key , value = naming_conv .split ("=" , 1 )
22+ except ValueError :
23+ raise ValueError ('Naming convention must be in "key=template" format' )
24+
25+ d [key ] = value
26+ return d
27+
28+
1729def main () -> None :
1830 generators = {ep .name : ep for ep in entry_points (group = "sqlacodegen.generators" )}
1931 parser = argparse .ArgumentParser (
@@ -40,6 +52,13 @@ def main() -> None:
4052 )
4153 parser .add_argument ("--noviews" , action = "store_true" , help = "ignore views" )
4254 parser .add_argument ("--outfile" , help = "file to write output to (default: stdout)" )
55+ parser .add_argument (
56+ "--conv" ,
57+ nargs = "*" ,
58+ help = 'constraint naming conventions in "key=template" format \
59+ e.g., --conv "pk=pk_%%(table_name)s" "uq=uq_%%(table_name)s_%%(column_0_name)s"' ,
60+ )
61+
4362 args = parser .parse_args ()
4463
4564 if args .version :
@@ -58,6 +77,11 @@ def main() -> None:
5877 for schema in schemas :
5978 metadata .reflect (engine , schema , not args .noviews , tables )
6079
80+ # Naming convention must be added after reflection to
81+ # avoid the token %(constraint_name)s duplicating the name
82+ if args .conv :
83+ metadata .naming_convention = parse_naming_convs (args .conv )
84+
6185 # Instantiate the generator
6286 generator_class = generators [args .generator ].load ()
6387 generator = generator_class (metadata , engine , set (args .option or ()))
0 commit comments