Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 116 additions & 126 deletions tests/integration/test_adapters/test_adbc/test_migrations.py

Large diffs are not rendered by default.

303 changes: 146 additions & 157 deletions tests/integration/test_adapters/test_aiosqlite/test_migrations.py

Large diffs are not rendered by default.

460 changes: 227 additions & 233 deletions tests/integration/test_adapters/test_asyncmy/test_migrations.py

Large diffs are not rendered by default.

786 changes: 376 additions & 410 deletions tests/integration/test_adapters/test_asyncpg/test_migrations.py

Large diffs are not rendered by default.

244 changes: 118 additions & 126 deletions tests/integration/test_adapters/test_duckdb/test_migrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Integration tests for DuckDB migration workflow."""

import tempfile
from pathlib import Path
from typing import Any

Expand All @@ -12,24 +11,23 @@
pytestmark = pytest.mark.xdist_group("duckdb")


def test_duckdb_migration_full_workflow() -> None:
def test_duckdb_migration_full_workflow(tmp_path: Path) -> None:
"""Test full DuckDB migration workflow: init -> create -> upgrade -> downgrade."""
with tempfile.TemporaryDirectory() as temp_dir:
migration_dir = Path(temp_dir) / "migrations"
db_path = Path(temp_dir) / "test.duckdb"
migration_dir = tmp_path / "migrations"
db_path = tmp_path / "test.duckdb"

config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

commands.init(str(migration_dir), package=True)
commands.init(str(migration_dir), package=True)

assert migration_dir.exists()
assert (migration_dir / "__init__.py").exists()
assert migration_dir.exists()
assert (migration_dir / "__init__.py").exists()

migration_content = '''"""Initial schema migration."""
migration_content = '''"""Initial schema migration."""


def up():
Expand All @@ -49,44 +47,43 @@ def down():
return ["DROP TABLE IF EXISTS users"]
'''

migration_file = migration_dir / "0001_create_users.py"
migration_file.write_text(migration_content)
migration_file = migration_dir / "0001_create_users.py"
migration_file.write_text(migration_content)

commands.upgrade()
commands.upgrade()

with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'")
assert len(result.data) == 1
with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'")
assert len(result.data) == 1

driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "John Doe", "[email protected]"))
driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "John Doe", "[email protected]"))

users_result = driver.execute("SELECT * FROM users")
assert len(users_result.data) == 1
assert users_result.data[0]["name"] == "John Doe"
assert users_result.data[0]["email"] == "[email protected]"
users_result = driver.execute("SELECT * FROM users")
assert len(users_result.data) == 1
assert users_result.data[0]["name"] == "John Doe"
assert users_result.data[0]["email"] == "[email protected]"

commands.downgrade("base")
commands.downgrade("base")

with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'")
assert len(result.data) == 0
with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'")
assert len(result.data) == 0


def test_duckdb_multiple_migrations_workflow() -> None:
def test_duckdb_multiple_migrations_workflow(tmp_path: Path) -> None:
"""Test DuckDB workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
with tempfile.TemporaryDirectory() as temp_dir:
migration_dir = Path(temp_dir) / "migrations"
db_path = Path(temp_dir) / "test.duckdb"
migration_dir = tmp_path / "migrations"
db_path = tmp_path / "test.duckdb"

config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

commands.init(str(migration_dir), package=True)
commands.init(str(migration_dir), package=True)

migration1_content = '''"""Create users table."""
migration1_content = '''"""Create users table."""


def up():
Expand All @@ -105,7 +102,7 @@ def down():
return ["DROP TABLE IF EXISTS users"]
'''

migration2_content = '''"""Create posts table."""
migration2_content = '''"""Create posts table."""


def up():
Expand All @@ -126,66 +123,63 @@ def down():
return ["DROP TABLE IF EXISTS posts"]
'''

(migration_dir / "0001_create_users.py").write_text(migration1_content)
(migration_dir / "0002_create_posts.py").write_text(migration2_content)
(migration_dir / "0001_create_users.py").write_text(migration1_content)
(migration_dir / "0002_create_posts.py").write_text(migration2_content)

commands.upgrade()
commands.upgrade()

with config.provide_session() as driver:
tables_result = driver.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name"
)
table_names = [t["table_name"] for t in tables_result.data]
assert "users" in table_names
assert "posts" in table_names
with config.provide_session() as driver:
tables_result = driver.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name"
)
table_names = [t["table_name"] for t in tables_result.data]
assert "users" in table_names
assert "posts" in table_names

driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "Author", "[email protected]"))
driver.execute(
"INSERT INTO posts (id, title, content, user_id) VALUES (?, ?, ?, ?)", (1, "My Post", "Post content", 1)
)
driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "Author", "[email protected]"))
driver.execute(
"INSERT INTO posts (id, title, content, user_id) VALUES (?, ?, ?, ?)", (1, "My Post", "Post content", 1)
)

posts_result = driver.execute("SELECT * FROM posts")
assert len(posts_result.data) == 1
assert posts_result.data[0]["title"] == "My Post"
posts_result = driver.execute("SELECT * FROM posts")
assert len(posts_result.data) == 1
assert posts_result.data[0]["title"] == "My Post"

commands.downgrade("0001")
commands.downgrade("0001")

with config.provide_session() as driver:
tables_result = driver.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
)
table_names = [t["table_name"] for t in tables_result.data]
assert "users" in table_names
assert "posts" not in table_names
with config.provide_session() as driver:
tables_result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'")
table_names = [t["table_name"] for t in tables_result.data]
assert "users" in table_names
assert "posts" not in table_names

commands.downgrade("base")
commands.downgrade("base")

with config.provide_session() as driver:
tables_result = driver.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' AND table_name NOT LIKE 'sqlspec_%'"
)
with config.provide_session() as driver:
tables_result = driver.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' AND table_name NOT LIKE 'sqlspec_%'"
)

table_names = [t["table_name"] for t in tables_result.data if not t["table_name"].startswith("sqlspec_")]
assert len(table_names) == 0
table_names = [t["table_name"] for t in tables_result.data if not t["table_name"].startswith("sqlspec_")]
assert len(table_names) == 0


def test_duckdb_migration_current_command() -> None:
def test_duckdb_migration_current_command(tmp_path: Path) -> None:
"""Test the current migration command shows correct version for DuckDB."""
with tempfile.TemporaryDirectory() as temp_dir:
migration_dir = Path(temp_dir) / "migrations"
db_path = Path(temp_dir) / "test.duckdb"
migration_dir = tmp_path / "migrations"
db_path = tmp_path / "test.duckdb"

config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

commands.init(str(migration_dir), package=True)
commands.init(str(migration_dir), package=True)

commands.current(verbose=False)
commands.current(verbose=False)

migration_content = '''"""Test migration."""
migration_content = '''"""Test migration."""


def up():
Expand All @@ -198,28 +192,27 @@ def down():
return ["DROP TABLE IF EXISTS test_table"]
'''

(migration_dir / "0001_test.py").write_text(migration_content)
(migration_dir / "0001_test.py").write_text(migration_content)

commands.upgrade()
commands.upgrade()

commands.current(verbose=True)
commands.current(verbose=True)


def test_duckdb_migration_error_handling() -> None:
def test_duckdb_migration_error_handling(tmp_path: Path) -> None:
"""Test DuckDB migration error handling."""
with tempfile.TemporaryDirectory() as temp_dir:
migration_dir = Path(temp_dir) / "migrations"
db_path = Path(temp_dir) / "test.duckdb"
migration_dir = tmp_path / "migrations"
db_path = tmp_path / "test.duckdb"

config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

commands.init(str(migration_dir), package=True)
commands.init(str(migration_dir), package=True)

migration_content = '''"""Bad migration."""
migration_content = '''"""Bad migration."""


def up():
Expand All @@ -232,30 +225,29 @@ def down():
return []
'''

(migration_dir / "0001_bad.py").write_text(migration_content)
(migration_dir / "0001_bad.py").write_text(migration_content)

commands.upgrade()
commands.upgrade()

with config.provide_session() as driver:
count = driver.select_value("SELECT COUNT(*) FROM sqlspec_migrations")
assert count == 0, f"Expected empty migration table after failed migration, but found {count} records"
with config.provide_session() as driver:
count = driver.select_value("SELECT COUNT(*) FROM sqlspec_migrations")
assert count == 0, f"Expected empty migration table after failed migration, but found {count} records"


def test_duckdb_migration_with_transactions() -> None:
def test_duckdb_migration_with_transactions(tmp_path: Path) -> None:
"""Test DuckDB migrations work properly with transactions."""
with tempfile.TemporaryDirectory() as temp_dir:
migration_dir = Path(temp_dir) / "migrations"
db_path = Path(temp_dir) / "test.duckdb"
migration_dir = tmp_path / "migrations"
db_path = tmp_path / "test.duckdb"

config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
config = DuckDBConfig(
pool_config={"database": str(db_path)},
migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
)
commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

commands.init(str(migration_dir), package=True)
commands.init(str(migration_dir), package=True)

migration_content = '''"""Migration with multiple operations."""
migration_content = '''"""Migration with multiple operations."""


def up():
Expand All @@ -275,18 +267,18 @@ def down():
return ["DROP TABLE IF EXISTS customers"]
'''

(migration_dir / "0001_transaction_test.py").write_text(migration_content)
(migration_dir / "0001_transaction_test.py").write_text(migration_content)

commands.upgrade()
commands.upgrade()

with config.provide_session() as driver:
customers_result = driver.execute("SELECT * FROM customers ORDER BY name")
assert len(customers_result.data) == 2
assert customers_result.data[0]["name"] == "Customer 1"
assert customers_result.data[1]["name"] == "Customer 2"
with config.provide_session() as driver:
customers_result = driver.execute("SELECT * FROM customers ORDER BY name")
assert len(customers_result.data) == 2
assert customers_result.data[0]["name"] == "Customer 1"
assert customers_result.data[1]["name"] == "Customer 2"

commands.downgrade("base")
commands.downgrade("base")

with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'customers'")
assert len(result.data) == 0
with config.provide_session() as driver:
result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'customers'")
assert len(result.data) == 0
Loading