diff --git a/CHANGES.rst b/CHANGES.rst index 96623d9a..959634c0 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,11 @@ Version history =============== +**UNRELEASED** + +- Handle SQLAlchemy type with unimplemented python_type as typing.Any (PR by @danplischke) +- Fix SQLModel metadata reference (PR by @danplischke) + **3.1.0** - Type annotations for ARRAY column attributes now include the Python type of diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 7b4901a7..db4fba0f 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -1243,9 +1243,18 @@ def render_python_type(column_type: TypeEngine[Any]) -> str: if isinstance(column_type, DOMAIN): python_type = column_type.data_type.python_type else: - python_type = column_type.python_type + try: + python_type = column_type.python_type + except NotImplementedError: + self.add_literal_import("typing", "Any") + python_type = Any + + python_type_name = ( + python_type.__name__ + if hasattr(python_type, "__name__") + else python_type._name + ) - python_type_name = python_type.__name__ python_type_module = python_type.__module__ if python_type_module == "builtins": return python_type_name @@ -1435,7 +1444,7 @@ def generate_base(self) -> None: self.base = Base( literal_imports=[], declarations=[], - metadata_ref="", + metadata_ref="SQLModel.metadata", ) def collect_imports(self, models: Iterable[Model]) -> None: diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index 307f865c..afec050f 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -2,7 +2,7 @@ import pytest from _pytest.fixtures import FixtureRequest -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import TSVECTOR, UUID from sqlalchemy.engine import Engine from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table from sqlalchemy.sql.expression import text @@ -267,3 +267,35 @@ class Simple(Base): id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True) """, ) + + +def test_tsvector_missing_python_type(generator: CodeGenerator) -> None: + Table( + "simple_tsvector", + generator.metadata, + Column("id", UUID, primary_key=True), + Column("vector", TSVECTOR), + ) + + validate_code( + generator.generate(), + """\ + from typing import Any, Optional + import typing + import uuid + + from sqlalchemy import UUID + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + + class Base(MappedAsDataclass, DeclarativeBase): + pass + + + class SimpleTsvector(Base): + __tablename__ = 'simple_tsvector' + + id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True) + vector: Mapped[Optional[typing.Any]] = mapped_column(TSVECTOR) + """, + ) diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 931d5965..97b4a4eb 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -4,7 +4,7 @@ from _pytest.fixtures import FixtureRequest from sqlalchemy import BIGINT, PrimaryKeyConstraint from sqlalchemy.dialects import postgresql -from sqlalchemy.dialects.postgresql import JSON, JSONB +from sqlalchemy.dialects.postgresql import JSON, JSONB, TSVECTOR from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, @@ -1706,3 +1706,34 @@ class TestDomainJson(Base): foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False)) """, ) + + +def test_tsvector_missing_python_type(generator: CodeGenerator) -> None: + Table( + "test_tsvector", + generator.metadata, + Column("id", BIGINT, primary_key=True), + Column("vector", TSVECTOR()), + ) + + validate_code( + generator.generate(), + """\ + from typing import Any, Optional + import typing + + from sqlalchemy import BigInteger + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class TestTsvector(Base): + __tablename__ = 'test_tsvector' + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + vector: Mapped[Optional[typing.Any]] = mapped_column(TSVECTOR) + """, + ) diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py index 32a736e2..f5b72837 100644 --- a/tests/test_generator_sqlmodel.py +++ b/tests/test_generator_sqlmodel.py @@ -3,6 +3,7 @@ import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy import Uuid +from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, @@ -204,3 +205,61 @@ class SimpleUuid(SQLModel, table=True): id: uuid.UUID = Field(sa_column=Column('id', Uuid, primary_key=True)) """, ) + + +def test_tsvector_missing_python_type(generator: CodeGenerator) -> None: + Table( + "simple_tsvector", + generator.metadata, + Column("id", Uuid, primary_key=True), + Column("search", TSVECTOR), + ) + + validate_code( + generator.generate(), + """\ + from typing import Any, Optional + import typing + import uuid + + from sqlalchemy import Column, Uuid + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlmodel import Field, SQLModel + + class SimpleTsvector(SQLModel, table=True): + __tablename__ = 'simple_tsvector' + + id: uuid.UUID = Field(sa_column=Column('id', Uuid, primary_key=True)) + search: Optional[typing.Any] = Field(default=None, sa_column=Column('search', TSVECTOR)) + """, + ) + + +def test_metadata_ref(generator: CodeGenerator) -> None: + from sqlmodel import SQLModel + + Table( + "metadata_ref_test_table", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + code = generator.generate() + validate_code( + code, + """\ + from sqlalchemy import Column, Integer + from sqlmodel import Field, SQLModel + + class MetadataRefTestTable(SQLModel, table=True): + __tablename__ = 'metadata_ref_test_table' + + id: int = Field(sa_column=Column('id', Integer, primary_key=True)) + """, + ) + + SQLModel.metadata.clear() # clear the metadata to avoid with the tables defined in this test + exec(code, globals()) + + assert len(SQLModel.metadata.tables) == 1 + assert "metadata_ref_test_table" in SQLModel.metadata.tables