Skip to content

Commit e4c32c5

Browse files
DOMAIN as JSON bindings fixes (#407)
* Test formatting fix * Added test coverage * test name fix * Support non-default jsons * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <[email protected]> * PR Fixes * PR Fixes --------- Co-authored-by: Alex Grönholm <[email protected]>
1 parent b01602e commit e4c32c5

File tree

5 files changed

+129
-6
lines changed

5 files changed

+129
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
dist
1414
build
1515
venv*
16+
docker-compose.yaml

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Version history
1010
- Fixed incorrect package name used in ``importlib.metadata.version`` for
1111
``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz)
1212
- Prevent double pluralization (PR by @dkratzert)
13+
- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon)
1314

1415
**3.0.0**
1516

src/sqlacodegen/generators.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
TypeDecorator,
3939
UniqueConstraint,
4040
)
41-
from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
41+
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
4242
from sqlalchemy.engine import Connection, Engine
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
@@ -222,7 +222,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
222222

223223
if isinstance(column.type, ARRAY):
224224
self.add_import(column.type.item_type.__class__)
225-
elif isinstance(column.type, JSONB):
225+
elif isinstance(column.type, (JSONB, JSON)):
226226
if (
227227
not isinstance(column.type.astext_type, Text)
228228
or column.type.astext_type.length is not None
@@ -499,7 +499,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
499499
else:
500500
return render_callable("mapped_column", *args, kwargs=kwargs)
501501

502-
def render_column_type(self, coltype: object) -> str:
502+
def render_column_type(self, coltype: TypeEngine[Any]) -> str:
503503
args = []
504504
kwargs: dict[str, Any] = {}
505505
sig = inspect.signature(coltype.__class__.__init__)
@@ -515,6 +515,17 @@ def render_column_type(self, coltype: object) -> str:
515515
continue
516516

517517
value = getattr(coltype, param.name, missing)
518+
519+
if isinstance(value, (JSONB, JSON)):
520+
# Remove astext_type if it's the default
521+
if (
522+
isinstance(value.astext_type, Text)
523+
and value.astext_type.length is None
524+
):
525+
value.astext_type = None # type: ignore[assignment]
526+
else:
527+
self.add_import(Text)
528+
518529
default = defaults.get(param.name, missing)
519530
if isinstance(value, TextClause):
520531
self.add_literal_import("sqlalchemy", "text")
@@ -547,7 +558,7 @@ def render_column_type(self, coltype: object) -> str:
547558
if (value := getattr(coltype, colname)) is not None:
548559
kwargs[colname] = repr(value)
549560

550-
if isinstance(coltype, JSONB):
561+
if isinstance(coltype, (JSONB, JSON)):
551562
# Remove astext_type if it's the default
552563
if (
553564
isinstance(coltype.astext_type, Text)
@@ -1224,7 +1235,11 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
12241235
return "".join(pre), column_type, "]" * post_size
12251236

12261237
def render_python_type(column_type: TypeEngine[Any]) -> str:
1227-
python_type = column_type.python_type
1238+
if isinstance(column_type, DOMAIN):
1239+
python_type = column_type.data_type.python_type
1240+
else:
1241+
python_type = column_type.python_type
1242+
12281243
python_type_name = python_type.__name__
12291244
python_type_module = python_type.__module__
12301245
if python_type_module == "builtins":

tests/test_generator_declarative.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5-
from sqlalchemy import PrimaryKeyConstraint
5+
from sqlalchemy import BIGINT, PrimaryKeyConstraint
6+
from sqlalchemy.dialects import postgresql
7+
from sqlalchemy.dialects.postgresql import JSON, JSONB
68
from sqlalchemy.engine import Engine
79
from sqlalchemy.schema import (
810
CheckConstraint,
@@ -1592,3 +1594,87 @@ class WithItems(Base):
15921594
str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2))
15931595
""",
15941596
)
1597+
1598+
1599+
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
1600+
def test_domain_json(generator: CodeGenerator) -> None:
1601+
Table(
1602+
"test_domain_json",
1603+
generator.metadata,
1604+
Column("id", BIGINT, primary_key=True),
1605+
Column(
1606+
"foo",
1607+
postgresql.DOMAIN(
1608+
"domain_json",
1609+
JSON,
1610+
not_null=False,
1611+
),
1612+
nullable=True,
1613+
),
1614+
)
1615+
1616+
validate_code(
1617+
generator.generate(),
1618+
"""\
1619+
from typing import Optional
1620+
1621+
from sqlalchemy import BigInteger
1622+
from sqlalchemy.dialects.postgresql import DOMAIN, JSON
1623+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1624+
1625+
class Base(DeclarativeBase):
1626+
pass
1627+
1628+
1629+
class TestDomainJson(Base):
1630+
__tablename__ = 'test_domain_json'
1631+
1632+
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
1633+
foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(), not_null=False))
1634+
""",
1635+
)
1636+
1637+
1638+
@pytest.mark.parametrize(
1639+
"domain_type",
1640+
[JSONB, JSON],
1641+
)
1642+
def test_domain_non_default_json(
1643+
generator: CodeGenerator,
1644+
domain_type: type[JSON] | type[JSONB],
1645+
) -> None:
1646+
Table(
1647+
"test_domain_json",
1648+
generator.metadata,
1649+
Column("id", BIGINT, primary_key=True),
1650+
Column(
1651+
"foo",
1652+
postgresql.DOMAIN(
1653+
"domain_json",
1654+
domain_type(astext_type=Text(128)),
1655+
not_null=False,
1656+
),
1657+
nullable=True,
1658+
),
1659+
)
1660+
1661+
validate_code(
1662+
generator.generate(),
1663+
f"""\
1664+
from typing import Optional
1665+
1666+
from sqlalchemy import BigInteger, Text
1667+
from sqlalchemy.dialects.postgresql import DOMAIN, {domain_type.__name__}
1668+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1669+
1670+
class Base(DeclarativeBase):
1671+
pass
1672+
1673+
1674+
class TestDomainJson(Base):
1675+
__tablename__ = 'test_domain_json'
1676+
1677+
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
1678+
foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False))
1679+
""",
1680+
)

tests/test_generator_tables.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,26 @@ def test_jsonb_default(generator: CodeGenerator) -> None:
181181
)
182182

183183

184+
@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
185+
def test_json_default(generator: CodeGenerator) -> None:
186+
Table("simple_items", generator.metadata, Column("json", postgresql.JSON))
187+
188+
validate_code(
189+
generator.generate(),
190+
"""\
191+
from sqlalchemy import Column, JSON, MetaData, Table
192+
193+
metadata = MetaData()
194+
195+
196+
t_simple_items = Table(
197+
'simple_items', metadata,
198+
Column('json', JSON)
199+
)
200+
""",
201+
)
202+
203+
184204
def test_enum_detection(generator: CodeGenerator) -> None:
185205
Table(
186206
"simple_items",

0 commit comments

Comments
 (0)