|
2 | 2 | from __future__ import annotations
|
3 | 3 |
|
4 | 4 | import warnings
|
| 5 | +from typing import cast |
| 6 | + |
| 7 | +from sqlalchemy import Table, create_engine |
| 8 | +from sqlalchemy.dialects import mssql, oracle, postgresql |
| 9 | +from sqlalchemy.orm import declarative_mixin |
| 10 | +from sqlalchemy.schema import CreateTable |
5 | 11 |
|
6 | 12 | from tests.helpers import purge_module
|
7 | 13 |
|
@@ -41,3 +47,115 @@ def test_deprecated_classes_functionality() -> None:
|
41 | 47 | assert hasattr(nanoid_pk, "_sentinel")
|
42 | 48 | assert hasattr(audit, "created_at")
|
43 | 49 | assert hasattr(audit, "updated_at")
|
| 50 | + |
| 51 | + |
| 52 | +def test_identity_primary_key_generates_identity_ddl() -> None: |
| 53 | + """Test that IdentityPrimaryKey generates proper IDENTITY DDL for PostgreSQL.""" |
| 54 | + from advanced_alchemy.base import BigIntBase |
| 55 | + from advanced_alchemy.mixins.bigint import IdentityPrimaryKey |
| 56 | + |
| 57 | + @declarative_mixin |
| 58 | + class TestMixin(IdentityPrimaryKey): |
| 59 | + __tablename__ = "test_identity" |
| 60 | + |
| 61 | + class TestModel(TestMixin, BigIntBase): |
| 62 | + pass |
| 63 | + |
| 64 | + # Get the CREATE TABLE statement |
| 65 | + create_stmt = CreateTable(cast(Table, TestModel.__table__)) |
| 66 | + |
| 67 | + # Test with PostgreSQL dialect |
| 68 | + pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] |
| 69 | + |
| 70 | + # Should contain GENERATED BY DEFAULT AS IDENTITY |
| 71 | + assert "GENERATED BY DEFAULT AS IDENTITY" in pg_ddl |
| 72 | + assert "BIGSERIAL" not in pg_ddl.upper() |
| 73 | + assert "START WITH 1" in pg_ddl |
| 74 | + assert "INCREMENT BY 1" in pg_ddl |
| 75 | + |
| 76 | + |
| 77 | +def test_identity_audit_base_generates_identity_ddl() -> None: |
| 78 | + """Test that IdentityAuditBase generates proper IDENTITY DDL for PostgreSQL.""" |
| 79 | + from advanced_alchemy.base import IdentityAuditBase |
| 80 | + |
| 81 | + class TestModel(IdentityAuditBase): |
| 82 | + __tablename__ = "test_identity_audit" |
| 83 | + |
| 84 | + # Get the CREATE TABLE statement |
| 85 | + create_stmt = CreateTable(cast(Table, TestModel.__table__)) |
| 86 | + |
| 87 | + # Test with PostgreSQL dialect |
| 88 | + pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] |
| 89 | + |
| 90 | + # Should contain GENERATED BY DEFAULT AS IDENTITY |
| 91 | + assert "GENERATED BY DEFAULT AS IDENTITY" in pg_ddl |
| 92 | + assert "BIGSERIAL" not in pg_ddl.upper() |
| 93 | + |
| 94 | + |
| 95 | +def test_bigint_primary_key_still_uses_sequence() -> None: |
| 96 | + """Test that BigIntPrimaryKey still uses sequences as before.""" |
| 97 | + from advanced_alchemy.base import BigIntBase |
| 98 | + from advanced_alchemy.mixins.bigint import BigIntPrimaryKey |
| 99 | + |
| 100 | + @declarative_mixin |
| 101 | + class TestMixin(BigIntPrimaryKey): |
| 102 | + __tablename__ = "test_bigint" |
| 103 | + |
| 104 | + class TestModel(TestMixin, BigIntBase): |
| 105 | + pass |
| 106 | + |
| 107 | + # Get the CREATE TABLE statement |
| 108 | + create_stmt = CreateTable(cast(Table, TestModel.__table__)) |
| 109 | + |
| 110 | + # Test with PostgreSQL dialect |
| 111 | + pg_ddl = str(create_stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call,unused-ignore] |
| 112 | + |
| 113 | + # BigIntPrimaryKey should use a Sequence (not IDENTITY) |
| 114 | + assert "GENERATED" not in pg_ddl |
| 115 | + assert "IDENTITY" not in pg_ddl.upper() |
| 116 | + # The sequence is defined on the column but rendered separately |
| 117 | + assert TestModel.__table__.c.id.default is not None |
| 118 | + assert TestModel.__table__.c.id.default.name == "test_bigint_id_seq" |
| 119 | + |
| 120 | + |
| 121 | +def test_identity_ddl_for_oracle() -> None: |
| 122 | + """Test Identity DDL generation for Oracle.""" |
| 123 | + from advanced_alchemy.base import IdentityAuditBase |
| 124 | + |
| 125 | + class TestModel(IdentityAuditBase): |
| 126 | + __tablename__ = "test_oracle" |
| 127 | + |
| 128 | + create_stmt = CreateTable(cast(Table, TestModel.__table__)) |
| 129 | + oracle_ddl = str(create_stmt.compile(dialect=oracle.dialect())) # type: ignore[no-untyped-call,unused-ignore] |
| 130 | + |
| 131 | + # Oracle should generate IDENTITY |
| 132 | + assert "GENERATED BY DEFAULT AS IDENTITY" in oracle_ddl |
| 133 | + |
| 134 | + |
| 135 | +def test_identity_ddl_for_mssql() -> None: |
| 136 | + """Test Identity DDL generation for SQL Server.""" |
| 137 | + from advanced_alchemy.base import IdentityAuditBase |
| 138 | + |
| 139 | + class TestModel(IdentityAuditBase): |
| 140 | + __tablename__ = "test_mssql" |
| 141 | + |
| 142 | + create_stmt = CreateTable(cast(Table, TestModel.__table__)) |
| 143 | + mssql_ddl = str(create_stmt.compile(dialect=mssql.dialect())) # type: ignore[no-untyped-call,unused-ignore] |
| 144 | + |
| 145 | + # SQL Server should generate IDENTITY |
| 146 | + assert "IDENTITY(1,1)" in mssql_ddl |
| 147 | + |
| 148 | + |
| 149 | +def test_identity_works_with_sqlite() -> None: |
| 150 | + """Test that Identity columns work with SQLite (fallback to autoincrement).""" |
| 151 | + from advanced_alchemy.base import IdentityAuditBase |
| 152 | + |
| 153 | + class TestModel(IdentityAuditBase): |
| 154 | + __tablename__ = "test_sqlite" |
| 155 | + |
| 156 | + # Create an in-memory SQLite engine |
| 157 | + engine = create_engine("sqlite:///:memory:") |
| 158 | + cast(Table, TestModel.__table__).create(engine) |
| 159 | + |
| 160 | + # Should not raise any errors |
| 161 | + assert True # If we get here, it worked |
0 commit comments