|
17 | 17 |
|
18 | 18 | import os
|
19 | 19 | from pathlib import Path
|
20 |
| -from typing import Any, Generator, List |
| 20 | +from typing import Any, Generator, List, cast |
21 | 21 |
|
22 | 22 | import pyarrow as pa
|
23 | 23 | import pytest
|
24 | 24 | from pydantic_core import ValidationError
|
25 | 25 | from pytest_lazyfixture import lazy_fixture
|
| 26 | +from sqlalchemy import Engine, create_engine, inspect |
26 | 27 | from sqlalchemy.exc import ArgumentError, IntegrityError
|
27 | 28 |
|
28 | 29 | from pyiceberg.catalog import (
|
29 | 30 | Catalog,
|
30 | 31 | load_catalog,
|
31 | 32 | )
|
32 |
| -from pyiceberg.catalog.sql import DEFAULT_ECHO_VALUE, DEFAULT_POOL_PRE_PING_VALUE, SqlCatalog |
| 33 | +from pyiceberg.catalog.sql import ( |
| 34 | + DEFAULT_ECHO_VALUE, |
| 35 | + DEFAULT_POOL_PRE_PING_VALUE, |
| 36 | + IcebergTables, |
| 37 | + SqlCatalog, |
| 38 | + SqlCatalogBaseTable, |
| 39 | +) |
33 | 40 | from pyiceberg.exceptions import (
|
34 | 41 | CommitFailedException,
|
35 | 42 | NamespaceAlreadyExistsError,
|
|
54 | 61 | from pyiceberg.typedef import Identifier
|
55 | 62 | from pyiceberg.types import IntegerType, strtobool
|
56 | 63 |
|
# Names of the SQL tables declared by the direct subclasses of SqlCatalogBaseTable.
# Evaluated once at import time, so only subclasses defined by then are included.
CATALOG_TABLES = [table_cls.__tablename__ for table_cls in SqlCatalogBaseTable.__subclasses__()]
| 65 | + |
57 | 66 |
|
58 | 67 | @pytest.fixture(scope="module")
|
59 | 68 | def catalog_name() -> str:
|
@@ -132,6 +141,16 @@ def catalog_sqlite(catalog_name: str, warehouse: Path) -> Generator[SqlCatalog,
|
132 | 141 | catalog.destroy_tables()
|
133 | 142 |
|
134 | 143 |
|
@pytest.fixture(scope="module")
def catalog_uri(warehouse: Path) -> str:
    """Return the SQLite URI of the ``sql-catalog.db`` file inside ``warehouse``."""
    uri = f"sqlite:////{warehouse}/sql-catalog.db"
    return uri
| 147 | + |
| 148 | + |
@pytest.fixture(scope="module")
def alchemy_engine(catalog_uri: str) -> Generator[Engine, None, None]:
    """Yield a SQLAlchemy engine bound to ``catalog_uri`` for the whole module.

    The engine is disposed once the module's tests are done so that its
    pooled connections do not stay open for the rest of the session.
    """
    engine = create_engine(catalog_uri)
    yield engine
    # Teardown: close every connection still held by the engine's pool.
    engine.dispose()
| 152 | + |
| 153 | + |
135 | 154 | @pytest.fixture(scope="module")
|
136 | 155 | def catalog_sqlite_without_rowcount(catalog_name: str, warehouse: Path) -> Generator[SqlCatalog, None, None]:
|
137 | 156 | props = {
|
@@ -225,6 +244,69 @@ def test_creation_from_impl(catalog_name: str, warehouse: Path) -> None:
|
225 | 244 | )
|
226 | 245 |
|
227 | 246 |
|
def confirm_no_tables_exist(alchemy_engine: Engine) -> None:
    """Drop any existing catalog tables, then verify that none remain.

    Despite the name this helper is destructive: every table declared by a
    direct subclass of ``SqlCatalogBaseTable`` that is present in the database
    behind ``alchemy_engine`` is dropped first. The test is then failed if any
    name listed in ``CATALOG_TABLES`` is still reported by the database.
    """
    inspector = inspect(alchemy_engine)
    for table_cls in SqlCatalogBaseTable.__subclasses__():
        if inspector.has_table(table_cls.__tablename__):
            table_cls.__table__.drop(alchemy_engine)

    # Use a fresh inspector for the verification so that reflection results
    # cached before the drops above cannot influence the answer.
    remaining_names = inspect(alchemy_engine).get_table_names()
    if any(name in CATALOG_TABLES for name in remaining_names):
        # pytest.raises() expects a callable (or use as a context manager);
        # handing it a message string never reports the intended failure.
        pytest.fail("Tables exist, but should not have been created yet")
| 256 | + |
| 257 | + |
def confirm_all_tables_exist(catalog: SqlCatalog) -> None:
    """Assert that ``catalog`` is a ``SqlCatalog`` whose database has every catalog table.

    The type check runs first so that a wrong catalog type is reported by its
    own assertion message rather than by an ``AttributeError`` on ``.engine``.
    """
    assert isinstance(catalog, SqlCatalog), "Catalog should be a SQLCatalog"

    # Query the table names once instead of once per expected table.
    existing_names = set(inspect(catalog.engine).get_table_names())
    all_tables_exists = all(name in existing_names for name in CATALOG_TABLES)
    assert all_tables_exists, "Tables should have been created"
| 266 | + |
| 267 | + |
def load_catalog_for_catalog_table_creation(catalog_name: str, catalog_uri: str) -> SqlCatalog:
    """Load a ``sql`` catalog for ``catalog_uri`` with ``init_catalog_tables`` set to ``"true"``.

    The result of ``load_catalog`` is cast to ``SqlCatalog`` for the type
    checker only; no runtime type check happens here.
    """
    props = {
        "type": "sql",
        "uri": catalog_uri,
        # Presumably asks the catalog to create its tables on start-up -- confirm in SqlCatalog.
        "init_catalog_tables": "true",
    }
    return cast(SqlCatalog, load_catalog(catalog_name, **props))
| 277 | + |
| 278 | + |
def test_creation_when_no_tables_exist(alchemy_engine: Engine, catalog_name: str, catalog_uri: str) -> None:
    """Loading the catalog against a database with no catalog tables leaves all of them present."""
    # Start from a database without any catalog tables.
    confirm_no_tables_exist(alchemy_engine)

    loaded = load_catalog_for_catalog_table_creation(catalog_name, catalog_uri)
    confirm_all_tables_exist(loaded)
| 283 | + |
| 284 | + |
def test_creation_when_one_tables_exists(alchemy_engine: Engine, catalog_name: str, catalog_uri: str) -> None:
    """Loading the catalog when only the ``IcebergTables`` table exists leaves all tables present."""
    confirm_no_tables_exist(alchemy_engine)

    # Pre-create just one of the catalog tables and check that it is visible.
    inspector = inspect(alchemy_engine)
    IcebergTables.__table__.create(bind=alchemy_engine)
    catalog_tables_present = [name for name in inspector.get_table_names() if name in CATALOG_TABLES]
    assert IcebergTables.__tablename__ in catalog_tables_present

    loaded = load_catalog_for_catalog_table_creation(catalog_name, catalog_uri)
    confirm_all_tables_exist(loaded)
| 295 | + |
| 296 | + |
def test_creation_when_all_tables_exists(alchemy_engine: Engine, catalog_name: str, catalog_uri: str) -> None:
    """Loading the catalog when every catalog table already exists leaves all of them present."""
    confirm_no_tables_exist(alchemy_engine)

    # Create all tables
    inspector = inspect(alchemy_engine)
    SqlCatalogBaseTable.metadata.create_all(bind=alchemy_engine)

    # Read the table names once and report exactly which tables are absent,
    # rather than re-querying per table and failing with a bare assert.
    present_names = set(inspector.get_table_names())
    missing = [name for name in CATALOG_TABLES if name not in present_names]
    assert not missing, f"Catalog tables missing after create_all: {missing}"

    catalog = load_catalog_for_catalog_table_creation(catalog_name=catalog_name, catalog_uri=catalog_uri)
    confirm_all_tables_exist(catalog)
| 308 | + |
| 309 | + |
228 | 310 | @pytest.mark.parametrize(
|
229 | 311 | "catalog",
|
230 | 312 | [
|
|
0 commit comments