diff --git a/advanced_alchemy/base.py b/advanced_alchemy/base.py index 47c80ef7..66087319 100644 --- a/advanced_alchemy/base.py +++ b/advanced_alchemy/base.py @@ -3,7 +3,7 @@ import contextlib import datetime import re -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, cast, runtime_checkable from uuid import UUID @@ -194,23 +194,162 @@ def to_dict(self, exclude: Optional[set[str]] = None) -> dict[str, Any]: class CommonTableAttributes(BasicAttributes): """Common attributes for SQLAlchemy tables. - Inherits from :class:`BasicAttributes` and provides a mechanism to infer table names from class names. + Inherits from :class:`BasicAttributes` and provides a mechanism to infer table names from class names + while respecting SQLAlchemy's inheritance patterns. + + This mixin supports all three SQLAlchemy inheritance patterns: + - **Single Table Inheritance (STI)**: Child classes automatically use parent's table + - **Joined Table Inheritance (JTI)**: Child classes have their own tables with foreign keys + - **Concrete Table Inheritance (CTI)**: Child classes have independent tables Attributes: - __tablename__ (str): The inferred table name. + __tablename__ (str | None): The inferred table name, or None for Single Table Inheritance children. """ + def __init_subclass__(cls, **kwargs: Any) -> None: + """Hook called when a subclass is created. + + This method intercepts class creation to correctly handle ``__tablename__`` for + Single Table Inheritance (STI) hierarchies. When a parent class explicitly + defines ``__tablename__``, subclasses would normally inherit that string value. + For STI, child classes must have ``__tablename__`` resolve to ``None`` to indicate + they share the parent's table. This hook enforces that rule. + + The detection logic identifies STI children by checking: + 1. Class doesn't explicitly define ``__tablename__`` in its own ``__dict__`` + 2. AND doesn't have ``concrete=True`` (which would make it CTI) + 3. AND doesn't define ``polymorphic_on`` in its own ``__mapper_args__`` (which would make it a base) + 4. AND inherits from a parent that defines ``polymorphic_on`` in ``__mapper_args__`` (STI hierarchy) + + For intermediate classes without ``polymorphic_identity`` but with a parent that has + ``polymorphic_on``, SQLAlchemy can emit a warning. When an intermediate class should + not be instantiated, set ``polymorphic_abstract=True`` in ``__mapper_args__`` or mark it + with ``__abstract__ = True``. + + This allows both usage patterns: + 1. Auto-generated names (don't set ``__tablename__`` on parent) + 2. Explicit names (set ``__tablename__`` on parent, STI still works) + """ + if "__tablename__" in cls.__dict__: + super().__init_subclass__(**kwargs) + return + + cls_dict = cast("Mapping[str, Any]", cls.__dict__) + own_mapper_args = cls_dict.get("__mapper_args__") + own_mapper_args_dict = cast("dict[str, Any]", own_mapper_args) if isinstance(own_mapper_args, dict) else {} + + if own_mapper_args_dict.get("concrete", False): + super().__init_subclass__(**kwargs) + return + + if "polymorphic_on" in own_mapper_args_dict: + super().__init_subclass__(**kwargs) + return + + for parent in cls.__mro__[1:]: + parent_mapper_args = getattr(parent, "__mapper_args__", None) + if isinstance(parent_mapper_args, dict) and "polymorphic_on" in parent_mapper_args: + cls.__tablename__ = None # type: ignore[misc] + break + + super().__init_subclass__(**kwargs) + if TYPE_CHECKING: - __tablename__: str + __tablename__: Optional[str] else: @declared_attr.directive - def __tablename__(cls) -> str: - """Infer table name from class name. + @classmethod + def __tablename__(cls) -> Optional[str]: + """Generate table name automatically for base models. + + This is called for models that do not have an explicit ``__tablename__``. + For STI child models, ``__init_subclass__`` will have already set + ``__tablename__ = None``, so this function returns ``None`` to indicate + the child should use the parent's table. + + The generation logic: + 1. If class explicitly defines ``__tablename__`` in its ``__dict__``, use that + 2. Otherwise, generate from class name using snake_case conversion Returns: - str: The inferred table name. + str | None: Table name generated from class name in snake_case, or None for STI children. + + Example: + Single Table Inheritance (both patterns work):: + + # Pattern 1: Auto-generated table name (recommended) + class Employee(UUIDBase): + # __tablename__ auto-generated as "employee" + type: Mapped[str] + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "employee", + } + + + class Manager(Employee): + # __tablename__ = None (set by __init_subclass__) + department: Mapped[str | None] + __mapper_args__ = {"polymorphic_identity": "manager"} + + + # Pattern 2: Explicit table name on parent + class Employee(UUIDBase): + __tablename__ = "custom_employee" # Explicit! + type: Mapped[str] + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "employee", + } + + + class Manager(Employee): + # __tablename__ = None (set by __init_subclass__) + # Still uses parent's "custom_employee" table + department: Mapped[str | None] + __mapper_args__ = {"polymorphic_identity": "manager"} + + Joined Table Inheritance:: + + class Employee(UUIDBase): + __tablename__ = "employee" + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type"} + + + class Manager(Employee): + __tablename__ = "manager" # Explicit - has own table + id: Mapped[int] = mapped_column( + ForeignKey("employee.id"), primary_key=True + ) + department: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "manager"} + + Concrete Table Inheritance:: + + class Employee(UUIDBase): + __tablename__ = "employee" + id: Mapped[int] = mapped_column(primary_key=True) + + + class Manager(Employee): + __tablename__ = "manager" # Independent table + __mapper_args__ = {"concrete": True} """ + cls_dict = cast("Mapping[str, Any]", cls.__dict__) + if "__tablename__" in cls_dict: + return cast("Optional[str]", cls_dict["__tablename__"]) + + mapper_args = getattr(cls, "__mapper_args__", {}) + mapper_args_dict = cast("dict[str, Any]", mapper_args) if isinstance(mapper_args, dict) else {} + if mapper_args_dict.get("concrete", False) or "polymorphic_on" in mapper_args_dict: + return table_name_regexp.sub(r"_\1", cls.__name__).lower() + + for parent in cls.__mro__[1:]: + parent_mapper_args = getattr(parent, "__mapper_args__", None) + if isinstance(parent_mapper_args, dict) and "polymorphic_on" in parent_mapper_args: + return None return table_name_regexp.sub(r"_\1", cls.__name__).lower() diff --git a/docs/changelog.rst b/docs/changelog.rst index 1ed0ec7a..b4e81512 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -21,7 +21,7 @@ Cause: We added the `sa.PasslibHasher = PasslibHasher` and `sa.PwdlibHasher = PwdlibHasher` types in `script.py.mako`. As a result, when a user installs only Advanced Alchemy and creates a migration, these files are imported. Since they reference types from `passlib` and `pwdlib`, which are not installed by default, the import fails and triggers this error. - .. change:: add missing type parameter to AsyncServiceT_co and SyncServiceT_… + .. change:: add missing type parameter to ``AsyncServiceT_co`` and ``SyncServiceT_co`` :type: bugfix :pr: 612 diff --git a/pyproject.toml b/pyproject.toml index f139f929..0a2e0285 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,6 +220,7 @@ filterwarnings = [ "ignore::DeprecationWarning:google.gcloud", "ignore::DeprecationWarning:google.iam", "ignore::DeprecationWarning:google", + "ignore:You are using a Python version \\(.*\\) which Google will stop supporting.*:FutureWarning:google.api_core._python_version_support", "ignore::DeprecationWarning:websockets.connection", "ignore::DeprecationWarning:websockets.legacy", "ignore:Accessing argon2.__version__ is deprecated:DeprecationWarning:passlib.handlers.argon2", diff --git a/tests/integration/test_inheritance.py b/tests/integration/test_inheritance.py new file mode 100644 index 00000000..5f400aea --- /dev/null +++ b/tests/integration/test_inheritance.py @@ -0,0 +1,567 @@ +"""Tests for SQLAlchemy inheritance pattern support. + +This module tests all three SQLAlchemy inheritance patterns: +- Single Table Inheritance (STI) +- Joined Table Inheritance (JTI) +- Concrete Table Inheritance (CTI) +""" + +import datetime +from typing import Optional + +import pytest +from sqlalchemy import ForeignKey, MetaData, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from advanced_alchemy import base + +# ============================================================================ +# Single Table Inheritance (STI) Tests +# ============================================================================ + + +@pytest.mark.integration +def test_sti_basic_table_names() -> None: + """STI: Child classes use parent table name (auto-generated).""" + # Create isolated base with unique metadata + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + # No explicit __tablename__ - let CommonTableAttributes generate it + class STIEmployee(base.CommonTableAttributes, LocalBase): + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + name: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class STIManager(STIEmployee): + department: Mapped[Optional[str]] = mapped_column(nullable=True) + __mapper_args__ = {"polymorphic_identity": "manager"} + + class STIEngineer(STIEmployee): + programming_language: Mapped[Optional[str]] = mapped_column(nullable=True) + __mapper_args__ = {"polymorphic_identity": "engineer"} + + # Verify all use same table (auto-generated from parent class name) + expected_name = "sti_employee" # snake_case of STIEmployee + assert STIEmployee.__table__.name == expected_name + assert STIManager.__table__.name == expected_name + assert STIEngineer.__table__.name == expected_name + assert STIManager.__table__ is STIEmployee.__table__ # Same table object + + +@pytest.mark.integration +def test_sti_table_columns() -> None: + """STI: Single table contains all columns from hierarchy.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "sti_employee_cols" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + name: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + department: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "manager"} + + class Engineer(Employee): + programming_language: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "engineer"} + + # Verify columns exist in single table + columns = {col.name for col in Employee.__table__.columns} + assert "type" in columns + assert "name" in columns + assert "department" in columns + assert "programming_language" in columns + + +@pytest.mark.integration +def test_sti_multi_level() -> None: + """STI: Three levels of inheritance share one table.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "sti_employee_ml" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + department: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "manager"} + + class SeniorManager(Manager): + budget: Mapped[Optional[int]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "senior_manager"} + + # All three levels use same table + assert Employee.__table__.name == "sti_employee_ml" + assert Manager.__table__.name == "sti_employee_ml" + assert SeniorManager.__table__.name == "sti_employee_ml" + + +@pytest.mark.integration +@pytest.mark.sqlite +def test_sti_crud_operations(sqlite_engine: Engine) -> None: + """STI: CRUD operations work correctly with polymorphic models.""" + from sqlalchemy.orm import Session as SessionType + + # Create fresh metadata and registry for this test + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "sti_employee_crud" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + name: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + department: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "manager"} + + class Engineer(Employee): + programming_language: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "engineer"} + + # Create tables + test_metadata.create_all(sqlite_engine) + + try: + with SessionType(sqlite_engine) as test_session: + # Create instances + manager = Manager(name="Alice", department="Engineering", type="manager") + engineer = Engineer(name="Bob", programming_language="Python", type="engineer") + employee = Employee(name="Charlie", type="employee") + + test_session.add_all([manager, engineer, employee]) + test_session.commit() + + # Query all employees + all_employees = test_session.execute(select(Employee)).scalars().all() + assert len(all_employees) == 3 + + # Query specific type + managers = test_session.execute(select(Manager)).scalars().all() + assert len(managers) == 1 + assert isinstance(managers[0], Manager) + assert managers[0].department == "Engineering" + + # Polymorphic identity check + retrieved_manager = test_session.execute(select(Employee).where(Employee.name == "Alice")).scalar_one() + assert isinstance(retrieved_manager, Manager) + assert retrieved_manager.department == "Engineering" + finally: + test_metadata.drop_all(sqlite_engine) + + +# ============================================================================ +# Joined Table Inheritance (JTI) Tests +# ============================================================================ + + +@pytest.mark.integration +def test_jti_basic() -> None: + """JTI: Child has separate table with foreign key.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "jti_employee" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + name: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + __tablename__ = "jti_manager" + id: Mapped[int] = mapped_column(ForeignKey("jti_employee.id"), primary_key=True) + department: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "manager"} + + # Verify separate tables + assert Employee.__table__.name == "jti_employee" + assert Manager.__table__.name == "jti_manager" + + # Verify foreign key relationship + fk_columns = [fk.parent.name for fk in Manager.__table__.foreign_keys] + assert "id" in fk_columns + + +@pytest.mark.integration +def test_jti_multiple_children() -> None: + """JTI: Multiple children each with own table.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "jti_employee_multi" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + __tablename__ = "jti_manager_multi" + id: Mapped[int] = mapped_column(ForeignKey("jti_employee_multi.id"), primary_key=True) + department: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "manager"} + + class Engineer(Employee): + __tablename__ = "jti_engineer_multi" + id: Mapped[int] = mapped_column(ForeignKey("jti_employee_multi.id"), primary_key=True) + language: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "engineer"} + + # Three separate tables + assert Employee.__table__.name == "jti_employee_multi" + assert Manager.__table__.name == "jti_manager_multi" + assert Engineer.__table__.name == "jti_engineer_multi" + + +@pytest.mark.integration +@pytest.mark.sqlite +def test_jti_crud_operations(sqlite_engine: Engine) -> None: + """JTI: CRUD operations with joined tables.""" + from sqlalchemy.orm import Session as SessionType + + # Create fresh metadata for this test + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "jti_employee_crud" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + name: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + __tablename__ = "jti_manager_crud" + id: Mapped[int] = mapped_column(ForeignKey("jti_employee_crud.id"), primary_key=True) + department: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "manager"} + + # Create tables + test_metadata.create_all(sqlite_engine) + + try: + with SessionType(sqlite_engine) as test_session: + # Create instance + manager = Manager(name="Alice", department="Engineering", type="manager") + test_session.add(manager) + test_session.commit() + + # Query + retrieved = test_session.execute(select(Manager)).scalar_one() + assert retrieved.name == "Alice" + assert retrieved.department == "Engineering" + + # Query as base class + as_employee = test_session.execute(select(Employee).where(Employee.name == "Alice")).scalar_one() + assert isinstance(as_employee, Manager) + finally: + test_metadata.drop_all(sqlite_engine) + + +# ============================================================================ +# Concrete Table Inheritance (CTI) Tests +# ============================================================================ + + +@pytest.mark.integration +def test_cti_basic() -> None: + """CTI: Child has independent table (no foreign key).""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "cti_employee" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + class Manager(Employee): + __tablename__ = "cti_manager" + id: Mapped[int] = mapped_column(primary_key=True) + department: Mapped[str] + __mapper_args__ = {"concrete": True} + + # Separate independent tables + assert Employee.__table__.name == "cti_employee" + assert Manager.__table__.name == "cti_manager" + + # No foreign keys + assert len(list(Manager.__table__.foreign_keys)) == 0 + + +@pytest.mark.integration +def test_cti_multiple_concrete_classes() -> None: + """CTI: Multiple concrete subclasses with independent tables.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "cti_employee_multi" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + class Manager(Employee): + __tablename__ = "cti_manager_multi" + id: Mapped[int] = mapped_column(primary_key=True) + department: Mapped[str] + __mapper_args__ = {"concrete": True} + + class Engineer(Employee): + __tablename__ = "cti_engineer_multi" + id: Mapped[int] = mapped_column(primary_key=True) + language: Mapped[str] + __mapper_args__ = {"concrete": True} + + # All have independent tables + assert Employee.__table__.name == "cti_employee_multi" + assert Manager.__table__.name == "cti_manager_multi" + assert Engineer.__table__.name == "cti_engineer_multi" + + # No foreign keys + assert len(list(Manager.__table__.foreign_keys)) == 0 + assert len(list(Engineer.__table__.foreign_keys)) == 0 + + +@pytest.mark.integration +@pytest.mark.sqlite +def test_cti_crud_operations(sqlite_engine: Engine) -> None: + """CTI: CRUD operations with concrete tables.""" + from sqlalchemy.orm import Session as SessionType + + # Create fresh metadata for this test + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "cti_employee_crud" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + class Manager(Employee): + __tablename__ = "cti_manager_crud" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] # Must redeclare inherited columns for CTI + department: Mapped[str] + __mapper_args__ = {"concrete": True} + + # Create tables + test_metadata.create_all(sqlite_engine) + + try: + with SessionType(sqlite_engine) as test_session: + # Create instance + manager = Manager(name="Alice", department="Engineering") + test_session.add(manager) + test_session.commit() + + # Query + retrieved = test_session.execute(select(Manager)).scalar_one() + assert retrieved.name == "Alice" + assert retrieved.department == "Engineering" + finally: + test_metadata.drop_all(sqlite_engine) + + +# ============================================================================ +# Edge Case Tests +# ============================================================================ + + +@pytest.mark.integration +def test_explicit_tablename_override() -> None: + """Explicit __tablename__ always respected.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "employee_explicit" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + __tablename__ = "manager_explicit_override" # Explicit override + id: Mapped[int] = mapped_column(ForeignKey("employee_explicit.id"), primary_key=True) + department: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "manager"} + + # Explicit tablename used (JTI pattern) + assert Manager.__table__.name == "manager_explicit_override" + + +@pytest.mark.integration +def test_mixin_with_inheritance() -> None: + """Mixins don't break inheritance detection.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class TimestampMixin: + created_at: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.now) + + class Employee(TimestampMixin, base.CommonTableAttributes, LocalBase): + __tablename__ = "employee_mixin" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + department: Mapped[Optional[str]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "manager"} + + # STI works despite mixin + assert Manager.__table__.name == "employee_mixin" + + +@pytest.mark.integration +def test_abstract_base_class() -> None: + """Abstract base classes handled correctly.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class BaseEntity(base.CommonTableAttributes, LocalBase): + __abstract__ = True + created_at: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.now) + + class Employee(BaseEntity): + __tablename__ = "employee_abstract" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + # Abstract base doesn't create table + assert not hasattr(BaseEntity, "__table__") + assert Employee.__table__.name == "employee_abstract" + + +@pytest.mark.integration +def test_no_inheritance_generates_tablename() -> None: + """Classes without inheritance get auto-generated tablename.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class StandaloneModel(base.CommonTableAttributes, LocalBase): + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + # Auto-generated from class name + assert StandaloneModel.__table__.name == "standalone_model" + + +@pytest.mark.integration +@pytest.mark.filterwarnings( + "ignore:Mapper\\[Manager\\(employee_no_poly_id\\)\\] does not indicate a 'polymorphic_identity'.*:" + "sqlalchemy.exc.SAWarning" +) +def test_sti_without_polymorphic_identity_on_child() -> None: + """STI child without explicit polymorphic_identity still uses parent table.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "employee_no_poly_id" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + department: Mapped[Optional[str]] = mapped_column(default=None) + # No __mapper_args__ - should still detect STI from parent + + # Should use parent table even without explicit polymorphic_identity + assert Manager.__table__.name == "employee_no_poly_id" + + +@pytest.mark.integration +def test_backward_compatibility_simple_models() -> None: + """Existing simple models without inheritance work as before.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class User(base.CommonTableAttributes, LocalBase): + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + email: Mapped[str] + + class Product(base.CommonTableAttributes, LocalBase): + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] + price: Mapped[int] + + # Auto-generated tablenames still work + assert User.__table__.name == "user" + assert Product.__table__.name == "product" + + +@pytest.mark.integration +def test_sti_with_multiple_inheritance_levels() -> None: + """Multi-level STI inheritance hierarchy.""" + test_metadata = MetaData() + + class LocalBase(DeclarativeBase): + metadata = test_metadata + + class Employee(base.CommonTableAttributes, LocalBase): + __tablename__ = "employee_deep" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "employee"} + + class Manager(Employee): + level: Mapped[Optional[int]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "manager"} + + class SeniorManager(Manager): + budget: Mapped[Optional[int]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "senior_manager"} + + class ExecutiveManager(SeniorManager): + bonus: Mapped[Optional[int]] = mapped_column(default=None) + __mapper_args__ = {"polymorphic_identity": "executive_manager"} + + # All levels use same table + assert Employee.__table__.name == "employee_deep" + assert Manager.__table__.name == "employee_deep" + assert SeniorManager.__table__.name == "employee_deep" + assert ExecutiveManager.__table__.name == "employee_deep" diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index 249cbeae..63a7720d 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -56,13 +56,13 @@ def test_identity_primary_key_generates_identity_ddl() -> None: @declarative_mixin class TestMixin(IdentityPrimaryKey): - __tablename__ = "test_identity" - - class TestModel(TestMixin, BigIntBase): pass + class IdentityPrimaryKeyModel(TestMixin, BigIntBase): + __tablename__ = "test_identity" + # Get the CREATE TABLE statement - create_stmt = CreateTable(cast(Table, TestModel.__table__)) + create_stmt = CreateTable(cast(Table, IdentityPrimaryKeyModel.__table__)) # Test with PostgreSQL dialect pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] @@ -78,11 +78,11 @@ def test_identity_audit_base_generates_identity_ddl() -> None: """Test that IdentityAuditBase generates proper IDENTITY DDL for PostgreSQL.""" from advanced_alchemy.base import IdentityAuditBase - class TestModel(IdentityAuditBase): + class IdentityAuditBaseModel(IdentityAuditBase): __tablename__ = "test_identity_audit" # Get the CREATE TABLE statement - create_stmt = CreateTable(cast(Table, TestModel.__table__)) + create_stmt = CreateTable(cast(Table, IdentityAuditBaseModel.__table__)) # Test with PostgreSQL dialect pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] @@ -99,13 +99,13 @@ def test_bigint_primary_key_still_uses_sequence() -> None: @declarative_mixin class TestMixin(BigIntPrimaryKey): - __tablename__ = "test_bigint" - - class TestModel(TestMixin, BigIntBase): pass + class BigIntPrimaryKeyModel(TestMixin, BigIntBase): + __tablename__ = "test_bigint" + # Get the CREATE TABLE statement - create_stmt = CreateTable(cast(Table, TestModel.__table__)) + create_stmt = CreateTable(cast(Table, BigIntPrimaryKeyModel.__table__)) # Test with PostgreSQL dialect pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] @@ -114,18 +114,18 @@ class TestModel(TestMixin, BigIntBase): assert "GENERATED" not in pg_ddl assert "IDENTITY" not in pg_ddl.upper() # The sequence is defined on the column but rendered separately - assert TestModel.__table__.c.id.default is not None - assert TestModel.__table__.c.id.default.name == "test_bigint_id_seq" + assert BigIntPrimaryKeyModel.__table__.c.id.default is not None + assert BigIntPrimaryKeyModel.__table__.c.id.default.name == "test_bigint_id_seq" def test_identity_ddl_for_oracle() -> None: """Test Identity DDL generation for Oracle.""" from advanced_alchemy.base import IdentityAuditBase - class TestModel(IdentityAuditBase): + class OracleIdentityAuditBaseModel(IdentityAuditBase): __tablename__ = "test_oracle" - create_stmt = CreateTable(cast(Table, TestModel.__table__)) + create_stmt = CreateTable(cast(Table, OracleIdentityAuditBaseModel.__table__)) oracle_ddl = str(create_stmt.compile(dialect=oracle.dialect())) # type: ignore[no-untyped-call,unused-ignore] # Oracle should generate IDENTITY @@ -136,10 +136,10 @@ def test_identity_ddl_for_mssql() -> None: """Test Identity DDL generation for SQL Server.""" from advanced_alchemy.base import IdentityAuditBase - class TestModel(IdentityAuditBase): + class MSSQLIdentityAuditBaseModel(IdentityAuditBase): __tablename__ = "test_mssql" - create_stmt = CreateTable(cast(Table, TestModel.__table__)) + create_stmt = CreateTable(cast(Table, MSSQLIdentityAuditBaseModel.__table__)) mssql_ddl = str(create_stmt.compile(dialect=mssql.dialect())) # type: ignore[no-untyped-call,unused-ignore] # SQL Server should generate IDENTITY @@ -150,12 +150,12 @@ def test_identity_works_with_sqlite() -> None: """Test that Identity columns work with SQLite (fallback to autoincrement).""" from advanced_alchemy.base import IdentityAuditBase - class TestModel(IdentityAuditBase): + class SQLiteIdentityAuditBaseModel(IdentityAuditBase): __tablename__ = "test_sqlite" # Create an in-memory SQLite engine engine = create_engine("sqlite:///:memory:") - cast(Table, TestModel.__table__).create(engine) + cast(Table, SQLiteIdentityAuditBaseModel.__table__).create(engine) # Should not raise any errors assert True # If we get here, it worked