Skip to content

Commit 84dcc39

Browse files
andrew222651pre-commit-ci[bot]agronholm
authored
Handle TextClause objects in DOMAIN expressions (#338)
* Handle TextClause objects in DOMAIN expressions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix line length * Update mypy config * Fix types * De-lint * Add test for integer type * Fix test * Handle TextClause objects in DOMAIN expressions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix line length * Update mypy config * Fix types * De-lint * Add test for integer type * Fix test * Ignore no-untyped-call * Ignore no-untyped-call after merge * Convert all TextClause kwargs to text() * Move TextClause handling * Generalize text import * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <[email protected]> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alex Grönholm <[email protected]>
1 parent d860a67 commit 84dcc39

File tree

5 files changed

+99
-11
lines changed

5 files changed

+99
-11
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ repos:
2828
- id: mypy
2929
additional_dependencies:
3030
- pytest
31-
- "sqlalchemy[mypy] < 2.0"
31+
- "SQLAlchemy >= 2.0.29"
3232

3333
- repo: https://github.com/pre-commit/pygrep-hooks
3434
rev: v1.10.0

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ classifiers = [
2929
]
3030
requires-python = ">=3.9"
3131
dependencies = [
32-
"SQLAlchemy >= 2.0.23",
32+
"SQLAlchemy >= 2.0.29",
3333
"inflect >= 4.0.0",
3434
"importlib_metadata; python_version < '3.10'",
3535
]
@@ -80,6 +80,7 @@ extend-select = [
8080

8181
[tool.mypy]
8282
strict = true
83+
disable_error_code = "no-untyped-call"
8384

8485
[tool.pytest.ini_options]
8586
addopts = "-rsfE --tb=short"

src/sqlacodegen/generators.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from keyword import iskeyword
1414
from pprint import pformat
1515
from textwrap import indent
16-
from typing import Any, ClassVar
16+
from typing import Any, ClassVar, Literal, cast
1717

1818
import inflect
1919
import sqlalchemy
@@ -38,7 +38,7 @@
3838
TypeDecorator,
3939
UniqueConstraint,
4040
)
41-
from sqlalchemy.dialects.postgresql import JSONB
41+
from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
4242
from sqlalchemy.engine import Connection, Engine
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
@@ -228,6 +228,8 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
228228
or column.type.astext_type.length is not None
229229
):
230230
self.add_import(column.type.astext_type)
231+
elif isinstance(column.type, DOMAIN):
232+
self.add_import(column.type.data_type.__class__)
231233

232234
if column.default:
233235
self.add_import(column.default)
@@ -375,7 +377,7 @@ def render_table(self, table: Table) -> str:
375377

376378
args.append(self.render_constraint(constraint))
377379

378-
for index in sorted(table.indexes, key=lambda i: i.name):
380+
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
379381
# One-column indexes should be rendered as index=True on columns
380382
if len(index.columns) > 1 or not uses_default_name(index):
381383
args.append(self.render_index(index))
@@ -467,7 +469,7 @@ def render_column(
467469

468470
if isinstance(column.server_default, DefaultClause):
469471
kwargs["server_default"] = render_callable(
470-
"text", repr(column.server_default.arg.text)
472+
"text", repr(cast(TextClause, column.server_default.arg).text)
471473
)
472474
elif isinstance(column.server_default, Computed):
473475
expression = str(column.server_default.sqltext)
@@ -514,12 +516,18 @@ def render_column_type(self, coltype: object) -> str:
514516

515517
value = getattr(coltype, param.name, missing)
516518
default = defaults.get(param.name, missing)
519+
if isinstance(value, TextClause):
520+
self.add_literal_import("sqlalchemy", "text")
521+
rendered_value = render_callable("text", repr(value.text))
522+
else:
523+
rendered_value = repr(value)
524+
517525
if value is missing or value == default:
518526
use_kwargs = True
519527
elif use_kwargs:
520-
kwargs[param.name] = repr(value)
528+
kwargs[param.name] = rendered_value
521529
else:
522-
args.append(repr(value))
530+
args.append(rendered_value)
523531

524532
vararg = next(
525533
(
@@ -1072,6 +1080,7 @@ def generate_relationship_name(
10721080
preferred_name = column_names[0][:-3]
10731081

10741082
if "use_inflect" in self.options:
1083+
inflected_name: str | Literal[False]
10751084
if relationship.type in (
10761085
RelationshipType.ONE_TO_MANY,
10771086
RelationshipType.MANY_TO_MANY,
@@ -1166,7 +1175,7 @@ def render_table_args(self, table: Table) -> str:
11661175
args.append(self.render_constraint(constraint))
11671176

11681177
# Render indexes
1169-
for index in sorted(table.indexes, key=lambda i: i.name):
1178+
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
11701179
if len(index.columns) > 1 or not uses_default_name(index):
11711180
args.append(self.render_index(index))
11721181

src/sqlacodegen/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from collections.abc import Mapping
5-
from typing import Any
5+
from typing import Any, Literal, cast
66

77
from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
88
from sqlalchemy.engine import Connection, Engine
@@ -97,6 +97,7 @@ def uses_default_name(constraint: Constraint | Index) -> bool:
9797
}
9898
)
9999

100+
key: Literal["fk", "pk", "ix", "ck", "uq"]
100101
if isinstance(constraint, Index):
101102
key = "ix"
102103
elif isinstance(constraint, CheckConstraint):
@@ -139,7 +140,10 @@ def uses_default_name(constraint: Constraint | Index) -> bool:
139140
raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}")
140141

141142
try:
142-
convention: str = table.metadata.naming_convention[key]
143+
convention = cast(
144+
Mapping[str, str],
145+
table.metadata.naming_convention,
146+
)[key]
143147
return constraint.name == (convention % values)
144148
except KeyError:
145149
return False

tests/test_generator_tables.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,80 @@ def test_enum_detection(generator: CodeGenerator) -> None:
205205
)
206206

207207

208+
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
209+
def test_domain_text(generator: CodeGenerator) -> None:
210+
Table(
211+
"simple_items",
212+
generator.metadata,
213+
Column(
214+
"postal_code",
215+
postgresql.DOMAIN(
216+
"us_postal_code",
217+
Text,
218+
constraint_name="valid_us_postal_code",
219+
not_null=False,
220+
check=text("VALUE ~ '^\\d{5}$' OR VALUE ~ '^\\d{5}-\\d{4}$'"),
221+
),
222+
nullable=False,
223+
),
224+
)
225+
226+
validate_code(
227+
generator.generate(),
228+
"""\
229+
from sqlalchemy import Column, MetaData, Table, Text, text
230+
from sqlalchemy.dialects.postgresql import DOMAIN
231+
232+
metadata = MetaData()
233+
234+
235+
t_simple_items = Table(
236+
'simple_items', metadata,
237+
Column('postal_code', DOMAIN('us_postal_code', Text(), \
238+
constraint_name='valid_us_postal_code', not_null=False, \
239+
check=text("VALUE ~ '^\\\\d{5}$' OR VALUE ~ '^\\\\d{5}-\\\\d{4}$'")), nullable=False)
240+
)
241+
""",
242+
)
243+
244+
245+
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
246+
def test_domain_int(generator: CodeGenerator) -> None:
247+
Table(
248+
"simple_items",
249+
generator.metadata,
250+
Column(
251+
"n",
252+
postgresql.DOMAIN(
253+
"positive_int",
254+
INTEGER,
255+
constraint_name="positive",
256+
not_null=False,
257+
check=text("VALUE > 0"),
258+
),
259+
nullable=False,
260+
),
261+
)
262+
263+
validate_code(
264+
generator.generate(),
265+
"""\
266+
from sqlalchemy import Column, INTEGER, MetaData, Table, text
267+
from sqlalchemy.dialects.postgresql import DOMAIN
268+
269+
metadata = MetaData()
270+
271+
272+
t_simple_items = Table(
273+
'simple_items', metadata,
274+
Column('n', DOMAIN('positive_int', INTEGER(), \
275+
constraint_name='positive', not_null=False, \
276+
check=text('VALUE > 0')), nullable=False)
277+
)
278+
""",
279+
)
280+
281+
208282
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
209283
def test_column_adaptation(generator: CodeGenerator) -> None:
210284
Table(

0 commit comments

Comments
 (0)