diff --git a/tests/integration/test_adapters/test_adbc/test_migrations.py b/tests/integration/test_adapters/test_adbc/test_migrations.py
index 6319d3dc..4e9f2bf9 100644
--- a/tests/integration/test_adapters/test_adbc/test_migrations.py
+++ b/tests/integration/test_adapters/test_adbc/test_migrations.py
@@ -1,6 +1,5 @@
 """Integration tests for ADBC migration workflow."""

-import tempfile
 from pathlib import Path
 from typing import Any

@@ -13,24 +12,23 @@


 @pytest.mark.xdist_group("sqlite")
-def test_adbc_sqlite_migration_full_workflow() -> None:
+def test_adbc_sqlite_migration_full_workflow(tmp_path: Path) -> None:
     """Test full ADBC SQLite migration workflow: init -> create -> upgrade -> downgrade."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AdbcConfig(
-            connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
-            migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    config = AdbcConfig(
+        connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
+        migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

-        commands.init(str(migration_dir), package=True)
+    commands.init(str(migration_dir), package=True)

-        assert migration_dir.exists()
-        assert (migration_dir / "__init__.py").exists()
+    assert migration_dir.exists()
+    assert (migration_dir / "__init__.py").exists()

-        migration_content = '''"""Initial schema migration."""
+    migration_content = '''"""Initial schema migration."""


 def up():
@@ -50,27 +48,27 @@ def down():
     return ["DROP TABLE IF EXISTS users"]
 '''

-        migration_file = migration_dir / "0001_create_users.py"
-        migration_file.write_text(migration_content)
+    migration_file = migration_dir / "0001_create_users.py"
+    migration_file.write_text(migration_content)

-        commands.upgrade()
+    commands.upgrade()

-        with config.provide_session() as driver:
-            result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
-            assert len(result.data) == 1
+    with config.provide_session() as driver:
+        result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
+        assert len(result.data) == 1

-            driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("John Doe", "john@example.com"))
+        driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("John Doe", "john@example.com"))

-            users_result = driver.execute("SELECT * FROM users")
-            assert len(users_result.data) == 1
-            assert users_result.data[0]["name"] == "John Doe"
-            assert users_result.data[0]["email"] == "john@example.com"
+        users_result = driver.execute("SELECT * FROM users")
+        assert len(users_result.data) == 1
+        assert users_result.data[0]["name"] == "John Doe"
+        assert users_result.data[0]["email"] == "john@example.com"

-        commands.downgrade("base")
+    commands.downgrade("base")

-        with config.provide_session() as driver:
-            result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
-            assert len(result.data) == 0
+    with config.provide_session() as driver:
+        result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'")
+        assert len(result.data) == 0


 @pytest.mark.xdist_group("postgres")
@@ -80,21 +78,20 @@ def test_adbc_postgresql_migration_workflow() -> None:


 @pytest.mark.xdist_group("sqlite")
-def test_adbc_multiple_migrations_workflow() -> None:
+def test_adbc_multiple_migrations_workflow(tmp_path: Path) -> None:
     """Test ADBC workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AdbcConfig(
-            connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
-            migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    config = AdbcConfig(
+        connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
+        migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

-        commands.init(str(migration_dir), package=True)
+    commands.init(str(migration_dir), package=True)

-        migration1_content = '''"""Create users table."""
+    migration1_content = '''"""Create users table."""


 def up():
@@ -113,7 +110,7 @@ def down():
     return ["DROP TABLE IF EXISTS users"]
 '''

-        migration2_content = '''"""Create posts table."""
+    migration2_content = '''"""Create posts table."""


 def up():
@@ -134,63 +131,58 @@ def down():
     return ["DROP TABLE IF EXISTS posts"]
 '''

-        (migration_dir / "0001_create_users.py").write_text(migration1_content)
-        (migration_dir / "0002_create_posts.py").write_text(migration2_content)
+    (migration_dir / "0001_create_users.py").write_text(migration1_content)
+    (migration_dir / "0002_create_posts.py").write_text(migration2_content)

-        commands.upgrade()
+    commands.upgrade()

-        with config.provide_session() as driver:
-            tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
-            table_names = [t["name"] for t in tables_result.data]
-            assert "users" in table_names
-            assert "posts" in table_names
+    with config.provide_session() as driver:
+        tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+        table_names = [t["name"] for t in tables_result.data]
+        assert "users" in table_names
+        assert "posts" in table_names

-            driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("Author", "author@example.com"))
-            driver.execute(
-                "INSERT INTO posts (title, content, user_id) VALUES (?, ?, ?)", ("My Post", "Post content", 1)
-            )
+        driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("Author", "author@example.com"))
+        driver.execute("INSERT INTO posts (title, content, user_id) VALUES (?, ?, ?)", ("My Post", "Post content", 1))

-            posts_result = driver.execute("SELECT * FROM posts")
-            assert len(posts_result.data) == 1
-            assert posts_result.data[0]["title"] == "My Post"
+        posts_result = driver.execute("SELECT * FROM posts")
+        assert len(posts_result.data) == 1
+        assert posts_result.data[0]["title"] == "My Post"

-        commands.downgrade("0001")
+    commands.downgrade("0001")

-        with config.provide_session() as driver:
-            tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table'")
-            table_names = [t["name"] for t in tables_result.data]
-            assert "users" in table_names
-            assert "posts" not in table_names
+    with config.provide_session() as driver:
+        tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        table_names = [t["name"] for t in tables_result.data]
+        assert "users" in table_names
+        assert "posts" not in table_names

-        commands.downgrade("base")
+    commands.downgrade("base")

-        with config.provide_session() as driver:
-            tables_result = driver.execute(
-                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
-            )
+    with config.provide_session() as driver:
+        tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")

-            table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")]
-            assert len(table_names) == 0
+        table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")]
+        assert len(table_names) == 0


 @pytest.mark.xdist_group("sqlite")
-def test_adbc_migration_current_command() -> None:
+def test_adbc_migration_current_command(tmp_path: Path) -> None:
     """Test the current migration command shows correct version for ADBC."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AdbcConfig(
-            connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
-            migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    config = AdbcConfig(
+        connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
+        migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

-        commands.init(str(migration_dir), package=True)
+    commands.init(str(migration_dir), package=True)

-        commands.current(verbose=False)
+    commands.current(verbose=False)

-        migration_content = '''"""Test migration."""
+    migration_content = '''"""Test migration."""


 def up():
@@ -203,29 +195,28 @@ def down():
     return ["DROP TABLE IF EXISTS test_table"]
 '''

-        (migration_dir / "0001_test.py").write_text(migration_content)
+    (migration_dir / "0001_test.py").write_text(migration_content)

-        commands.upgrade()
+    commands.upgrade()

-        commands.current(verbose=True)
+    commands.current(verbose=True)


 @pytest.mark.xdist_group("sqlite")
-def test_adbc_migration_error_handling() -> None:
+def test_adbc_migration_error_handling(tmp_path: Path) -> None:
     """Test ADBC migration error handling."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AdbcConfig(
-            connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
-            migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    config = AdbcConfig(
+        connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
+        migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

-        commands.init(str(migration_dir), package=True)
+    commands.init(str(migration_dir), package=True)

-        migration_content = '''"""Bad migration."""
+    migration_content = '''"""Bad migration."""


 def up():
@@ -238,35 +229,34 @@ def down():
     return []
 '''

-        (migration_dir / "0001_bad.py").write_text(migration_content)
+    (migration_dir / "0001_bad.py").write_text(migration_content)

-        commands.upgrade()
+    commands.upgrade()

-        with config.provide_session() as driver:
-            try:
-                driver.execute("SELECT version FROM sqlspec_migrations ORDER BY version")
-                msg = "Expected migration table to not exist, but it does"
-                raise AssertionError(msg)
-            except Exception as e:
-                assert "no such" in str(e).lower() or "does not exist" in str(e).lower()
+    with config.provide_session() as driver:
+        try:
+            driver.execute("SELECT version FROM sqlspec_migrations ORDER BY version")
+            msg = "Expected migration table to not exist, but it does"
+            raise AssertionError(msg)
+        except Exception as e:
+            assert "no such" in str(e).lower() or "does not exist" in str(e).lower()


 @pytest.mark.xdist_group("sqlite")
-def test_adbc_migration_with_transactions() -> None:
+def test_adbc_migration_with_transactions(tmp_path: Path) -> None:
     """Test ADBC migrations work properly with transactions."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AdbcConfig(
-            connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
-            migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    config = AdbcConfig(
+        connection_config={"driver_name": "adbc_driver_sqlite", "uri": f"file:{db_path}", "autocommit": True},
+        migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)

-        commands.init(str(migration_dir), package=True)
+    commands.init(str(migration_dir), package=True)

-        migration_content = '''"""Migration with multiple operations."""
+    migration_content = '''"""Migration with multiple operations."""


 def up():
@@ -286,18 +276,18 @@ def down():
     return ["DROP TABLE IF EXISTS customers"]
 '''

-        (migration_dir / "0001_transaction_test.py").write_text(migration_content)
+    (migration_dir / "0001_transaction_test.py").write_text(migration_content)

-        commands.upgrade()
+    commands.upgrade()

-        with config.provide_session() as driver:
-            customers_result = driver.execute("SELECT * FROM customers ORDER BY name")
-            assert len(customers_result.data) == 2
-            assert customers_result.data[0]["name"] == "Customer 1"
-            assert customers_result.data[1]["name"] == "Customer 2"
+    with config.provide_session() as driver:
+        customers_result = driver.execute("SELECT * FROM customers ORDER BY name")
+        assert len(customers_result.data) == 2
+        assert customers_result.data[0]["name"] == "Customer 1"
+        assert customers_result.data[1]["name"] == "Customer 2"

-        commands.downgrade("base")
+    commands.downgrade("base")

-        with config.provide_session() as driver:
-            result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='customers'")
-            assert len(result.data) == 0
+    with config.provide_session() as driver:
+        result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='customers'")
+        assert len(result.data) == 0
diff --git a/tests/integration/test_adapters/test_aiosqlite/test_migrations.py b/tests/integration/test_adapters/test_aiosqlite/test_migrations.py
index 52a13cd9..ecd024e1 100644
--- a/tests/integration/test_adapters/test_aiosqlite/test_migrations.py
+++ b/tests/integration/test_adapters/test_aiosqlite/test_migrations.py
@@ -1,6 +1,5 @@
 """Integration tests for AioSQLite migration workflow."""

-import tempfile
 from pathlib import Path

 import pytest
@@ -11,29 +10,28 @@
 pytestmark = pytest.mark.xdist_group("sqlite")


-async def test_aiosqlite_migration_full_workflow() -> None:
+async def test_aiosqlite_migration_full_workflow(tmp_path: Path) -> None:
     """Test full AioSQLite migration workflow: init -> create -> upgrade -> downgrade."""
     test_id = "aiosqlite_full_workflow"
     migration_table = f"sqlspec_migrations_{test_id}"
     users_table = f"users_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AiosqliteConfig(
-            pool_config={"database": str(db_path)},
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    config = AiosqliteConfig(
+        pool_config={"database": str(db_path)},
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        await commands.init(str(migration_dir), package=True)
+    await commands.init(str(migration_dir), package=True)

-        assert migration_dir.exists()
-        assert (migration_dir / "__init__.py").exists()
+    assert migration_dir.exists()
+    assert (migration_dir / "__init__.py").exists()

-        migration_content = f'''"""Initial schema migration."""
+    migration_content = f'''"""Initial schema migration."""


 def up():
@@ -53,38 +51,34 @@ def down():
     return ["DROP TABLE IF EXISTS {users_table}"]
 '''

-        migration_file = migration_dir / "0001_create_users.py"
-        migration_file.write_text(migration_content)
+    migration_file = migration_dir / "0001_create_users.py"
+    migration_file.write_text(migration_content)

-        await commands.upgrade()
+    await commands.upgrade()

-        async with config.provide_session() as driver:
-            result = await driver.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{users_table}'")
-            assert len(result.data) == 1
+    async with config.provide_session() as driver:
+        result = await driver.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{users_table}'")
+        assert len(result.data) == 1

-            await driver.execute(
-                f"INSERT INTO {users_table} (name, email) VALUES (?, ?)", ("John Doe", "john@example.com")
-            )
+        await driver.execute(f"INSERT INTO {users_table} (name, email) VALUES (?, ?)", ("John Doe", "john@example.com"))

-            users_result = await driver.execute(f"SELECT * FROM {users_table}")
-            assert len(users_result.data) == 1
-            assert users_result.data[0]["name"] == "John Doe"
-            assert users_result.data[0]["email"] == "john@example.com"
+        users_result = await driver.execute(f"SELECT * FROM {users_table}")
+        assert len(users_result.data) == 1
+        assert users_result.data[0]["name"] == "John Doe"
+        assert users_result.data[0]["email"] == "john@example.com"

-        try:
-            await commands.downgrade("base")
+    try:
+        await commands.downgrade("base")

-            async with config.provide_session() as driver:
-                result = await driver.execute(
-                    f"SELECT name FROM sqlite_master WHERE type='table' AND name='{users_table}'"
-                )
-                assert len(result.data) == 0
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        async with config.provide_session() as driver:
+            result = await driver.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{users_table}'")
+            assert len(result.data) == 0
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_aiosqlite_multiple_migrations_workflow() -> None:
+async def test_aiosqlite_multiple_migrations_workflow(tmp_path: Path) -> None:
     """Test AioSQLite workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
     test_id = "aiosqlite_multiple_workflow"
@@ -92,19 +86,18 @@ async def test_aiosqlite_multiple_migrations_workflow() -> None:
     users_table = f"users_{test_id}"
     posts_table = f"posts_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AiosqliteConfig(
-            pool_config={"database": str(db_path)},
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    config = AiosqliteConfig(
+        pool_config={"database": str(db_path)},
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        await commands.init(str(migration_dir), package=True)
+    await commands.init(str(migration_dir), package=True)

-        migration1_content = f'''"""Create users table."""
+    migration1_content = f'''"""Create users table."""


 def up():
@@ -123,7 +116,7 @@ def down():
     return ["DROP TABLE IF EXISTS {users_table}"]
 '''

-        migration2_content = f'''"""Create posts table."""
+    migration2_content = f'''"""Create posts table."""


 def up():
@@ -144,75 +137,73 @@ def down():
     return ["DROP TABLE IF EXISTS {posts_table}"]
 '''

-        (migration_dir / "0001_create_users.py").write_text(migration1_content)
-        (migration_dir / "0002_create_posts.py").write_text(migration2_content)
+    (migration_dir / "0001_create_users.py").write_text(migration1_content)
+    (migration_dir / "0002_create_posts.py").write_text(migration2_content)

-        try:
-            await commands.upgrade()
+    try:
+        await commands.upgrade()

-            async with config.provide_session() as driver:
-                tables_result = await driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
-                table_names = [t["name"] for t in tables_result.data]
-                assert users_table in table_names
-                assert posts_table in table_names
+        async with config.provide_session() as driver:
+            tables_result = await driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+            table_names = [t["name"] for t in tables_result.data]
+            assert users_table in table_names
+            assert posts_table in table_names

-                await driver.execute(
-                    f"INSERT INTO {users_table} (name, email) VALUES (?, ?)", ("Author", "author@example.com")
-                )
-                await driver.execute(
-                    f"INSERT INTO {posts_table} (title, content, user_id) VALUES (?, ?, ?)",
-                    ("My Post", "Post content", 1),
-                )
+            await driver.execute(
+                f"INSERT INTO {users_table} (name, email) VALUES (?, ?)", ("Author", "author@example.com")
+            )
+            await driver.execute(
+                f"INSERT INTO {posts_table} (title, content, user_id) VALUES (?, ?, ?)", ("My Post", "Post content", 1)
+            )

-                posts_result = await driver.execute(f"SELECT * FROM {posts_table}")
-                assert len(posts_result.data) == 1
-                assert posts_result.data[0]["title"] == "My Post"
+            posts_result = await driver.execute(f"SELECT * FROM {posts_table}")
+            assert len(posts_result.data) == 1
+            assert posts_result.data[0]["title"] == "My Post"

-            await commands.downgrade("0001")
+        await commands.downgrade("0001")

-            async with config.provide_session() as driver:
-                tables_result = await driver.execute("SELECT name FROM sqlite_master WHERE type='table'")
-                table_names = [t["name"] for t in tables_result.data]
-                assert users_table in table_names
-                assert posts_table not in table_names
+        async with config.provide_session() as driver:
+            tables_result = await driver.execute("SELECT name FROM sqlite_master WHERE type='table'")
+            table_names = [t["name"] for t in tables_result.data]
+            assert users_table in table_names
+            assert posts_table not in table_names

-            await commands.downgrade("base")
+        await commands.downgrade("base")

-            async with config.provide_session() as driver:
-                tables_result = await driver.execute(
-                    "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
-                )
+        async with config.provide_session() as driver:
+            tables_result = await driver.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+            )

-                table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")]
-                assert len(table_names) == 0
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+            table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")]
+            assert len(table_names) == 0
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_aiosqlite_migration_current_command() -> None:
+async def test_aiosqlite_migration_current_command(tmp_path: Path) -> None:
     """Test the current migration command shows correct version for AioSQLite."""
     test_id = "aiosqlite_current_cmd"
     migration_table = f"sqlspec_migrations_{test_id}"
     test_table = f"test_table_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AiosqliteConfig(
-            pool_config={"database": str(db_path)},
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    config = AiosqliteConfig(
+        pool_config={"database": str(db_path)},
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        try:
-            await commands.init(str(migration_dir), package=True)
+    try:
+        await commands.init(str(migration_dir), package=True)

-            await commands.current(verbose=False)
+        await commands.current(verbose=False)

-            migration_content = f'''"""Test migration."""
+        migration_content = f'''"""Test migration."""


 def up():
@@ -225,36 +216,35 @@ def down():
     return ["DROP TABLE IF EXISTS {test_table}"]
 '''

-            (migration_dir / "0001_test.py").write_text(migration_content)
+        (migration_dir / "0001_test.py").write_text(migration_content)

-            await commands.upgrade()
+        await commands.upgrade()

-            await commands.current(verbose=True)
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        await commands.current(verbose=True)
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_aiosqlite_migration_error_handling() -> None:
+async def test_aiosqlite_migration_error_handling(tmp_path: Path) -> None:
     """Test AioSQLite migration error handling."""
     test_id = "aiosqlite_error_handling"
     migration_table = f"sqlspec_migrations_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AiosqliteConfig(
-            pool_config={"database": str(db_path)},
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    config = AiosqliteConfig(
+        pool_config={"database": str(db_path)},
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        try:
-            await commands.init(str(migration_dir), package=True)
+    try:
+        await commands.init(str(migration_dir), package=True)

-            migration_content = '''"""Bad migration."""
+        migration_content = '''"""Bad migration."""


 def up():
@@ -267,43 +257,42 @@ def down():
     return []
 '''

-            (migration_dir / "0001_bad.py").write_text(migration_content)
-
-            await commands.upgrade()
-
-            async with config.provide_session() as driver:
-                try:
-                    await driver.execute(f"SELECT version FROM {migration_table} ORDER BY version")
-                    msg = "Expected migration table to not exist, but it does"
-                    raise AssertionError(msg)
-                except Exception as e:
-                    assert "no such" in str(e).lower() or "does not exist" in str(e).lower()
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        (migration_dir / "0001_bad.py").write_text(migration_content)

+        await commands.upgrade()

-async def test_aiosqlite_migration_with_transactions() -> None:
+        async with config.provide_session() as driver:
+            try:
+                await driver.execute(f"SELECT version FROM {migration_table} ORDER BY version")
+                msg = "Expected migration table to not exist, but it does"
+                raise AssertionError(msg)
+            except Exception as e:
+                assert "no such" in str(e).lower() or "does not exist" in str(e).lower()
+    finally:
+        if config.pool_instance:
+            await config.close_pool()
+
+
+async def test_aiosqlite_migration_with_transactions(tmp_path: Path) -> None:
     """Test AioSQLite migrations work properly with transactions."""
     test_id = "aiosqlite_transactions"
     migration_table = f"sqlspec_migrations_{test_id}"
     customers_table = f"customers_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        db_path = Path(temp_dir) / "test.db"
+    migration_dir = tmp_path / "migrations"
+    db_path = tmp_path / "test.db"

-        config = AiosqliteConfig(
-            pool_config={"database": str(db_path)},
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    config = AiosqliteConfig(
+        pool_config={"database": str(db_path)},
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        try:
-            await commands.init(str(migration_dir), package=True)
+    try:
+        await commands.init(str(migration_dir), package=True)

-            migration_content = f'''"""Migration with multiple operations."""
+        migration_content = f'''"""Migration with multiple operations."""


 def up():
@@ -323,23 +312,23 @@ def down():
     return ["DROP TABLE IF EXISTS {customers_table}"]
 '''

-            (migration_dir / "0001_transaction_test.py").write_text(migration_content)
+        (migration_dir / "0001_transaction_test.py").write_text(migration_content)

-            await commands.upgrade()
+        await commands.upgrade()

-            async with config.provide_session() as driver:
-                customers_result = await driver.execute(f"SELECT * FROM {customers_table} ORDER BY name")
-                assert len(customers_result.data) == 2
-                assert customers_result.data[0]["name"] == "Customer 1"
-                assert customers_result.data[1]["name"] == "Customer 2"
+        async with config.provide_session() as driver:
+            customers_result = await driver.execute(f"SELECT * FROM {customers_table} ORDER BY name")
+            assert len(customers_result.data) == 2
+            assert customers_result.data[0]["name"] == "Customer 1"
+            assert customers_result.data[1]["name"] == "Customer 2"

-            await commands.downgrade("base")
+        await commands.downgrade("base")

-            async with config.provide_session() as driver:
-                result = await driver.execute(
-                    f"SELECT name FROM sqlite_master WHERE type='table' AND name='{customers_table}'"
-                )
-                assert len(result.data) == 0
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        async with config.provide_session() as driver:
+            result = await driver.execute(
+                f"SELECT name FROM sqlite_master WHERE type='table' AND name='{customers_table}'"
+            )
+            assert len(result.data) == 0
+    finally:
+        if config.pool_instance:
+            await config.close_pool()
diff --git a/tests/integration/test_adapters/test_asyncmy/test_migrations.py b/tests/integration/test_adapters/test_asyncmy/test_migrations.py
index 93ac1c0c..70a4773f 100644
--- a/tests/integration/test_adapters/test_asyncmy/test_migrations.py
+++ b/tests/integration/test_adapters/test_asyncmy/test_migrations.py
@@ -1,6 +1,5 @@
 """Integration tests for Asyncmy (MySQL) migration workflow."""

-import tempfile
 from pathlib import Path

 import pytest
@@ -12,35 +11,34 @@
 pytestmark = pytest.mark.xdist_group("mysql")


-async def test_asyncmy_migration_full_workflow(mysql_service: MySQLService) -> None:
+async def test_asyncmy_migration_full_workflow(tmp_path: Path, mysql_service: MySQLService) -> None:
     """Test full Asyncmy migration workflow: init -> create -> upgrade -> downgrade."""
     test_id = "asyncmy_full_workflow"
     migration_table = f"sqlspec_migrations_{test_id}"
     users_table = f"users_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = AsyncmyConfig(
-            pool_config={
-                "host": mysql_service.host,
-                "port": mysql_service.port,
-                "user": mysql_service.user,
-                "password": mysql_service.password,
-                "database": mysql_service.db,
-                "autocommit": True,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    migration_dir = tmp_path / "migrations"

-        await commands.init(str(migration_dir), package=True)
+    config = AsyncmyConfig(
+        pool_config={
+            "host": mysql_service.host,
+            "port": mysql_service.port,
+            "user": mysql_service.user,
+            "password": mysql_service.password,
+            "database": mysql_service.db,
+            "autocommit": True,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)

-        assert migration_dir.exists()
-        assert (migration_dir / "__init__.py").exists()
+    await commands.init(str(migration_dir), package=True)

-        migration_content = f'''"""Initial schema migration."""
+    assert migration_dir.exists()
+    assert (migration_dir / "__init__.py").exists()
+
+    migration_content = f'''"""Initial schema migration."""


 def up():
@@ -60,42 +58,42 @@ def down():
     return ["DROP TABLE IF EXISTS {users_table}"]
 '''

-        migration_file = migration_dir / "0001_create_users.py"
-        migration_file.write_text(migration_content)
+    migration_file = migration_dir / "0001_create_users.py"
+    migration_file.write_text(migration_content)

-        try:
-            await commands.upgrade()
+    try:
+        await commands.upgrade()

-            async with config.provide_session() as driver:
-                result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
-                    (mysql_service.db,),
-                )
-                assert len(result.data) == 1
+        async with config.provide_session() as driver:
+            result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
+                (mysql_service.db,),
+            )
+            assert len(result.data) == 1

-                await driver.execute(
-                    f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com")
-                )
+            await driver.execute(
+                f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com")
+            )

-                users_result = await driver.execute(f"SELECT * FROM {users_table}")
-                assert len(users_result.data) == 1
-                assert users_result.data[0]["name"] == "John Doe"
-                assert users_result.data[0]["email"] == "john@example.com"
+            users_result = await driver.execute(f"SELECT * FROM {users_table}")
+            assert len(users_result.data) == 1
+            assert users_result.data[0]["name"] == "John Doe"
+            assert users_result.data[0]["email"] == "john@example.com"

-            await commands.downgrade("base")
+        await commands.downgrade("base")

-            async with config.provide_session() as driver:
-                result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
-                    (mysql_service.db,),
-                )
-                assert len(result.data) == 0
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        async with config.provide_session() as driver:
+            result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
+                (mysql_service.db,),
+            )
+            assert len(result.data) == 0
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_asyncmy_multiple_migrations_workflow(mysql_service: MySQLService) -> None:
+async def test_asyncmy_multiple_migrations_workflow(tmp_path: Path, mysql_service: MySQLService) -> None:
     """Test Asyncmy workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
     test_id = "asyncmy_multiple_workflow"
@@ -103,25 +101,24 @@ async def test_asyncmy_multiple_migrations_workflow(mysql_service: MySQLService)
     users_table = f"users_{test_id}"
     posts_table = f"posts_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = AsyncmyConfig(
-            pool_config={
-                "host": mysql_service.host,
-                "port": mysql_service.port,
-                "user": mysql_service.user,
-                "password": mysql_service.password,
-                "database": mysql_service.db,
-                "autocommit": True,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    migration_dir = tmp_path / "migrations"

-        await commands.init(str(migration_dir), package=True)
+    config = AsyncmyConfig(
+        pool_config={
+            "host": mysql_service.host,
+            "port": mysql_service.port,
+            "user": mysql_service.user,
+            "password": mysql_service.password,
+            "database": mysql_service.db,
+            "autocommit": True,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)
+
+    await commands.init(str(migration_dir), package=True)

-        migration1_content = f'''"""Create users table."""
+    migration1_content = f'''"""Create users table."""


 def up():
@@ -140,9 +137,9 @@ def down():
     """Drop users table."""
     return ["DROP TABLE IF EXISTS {users_table}"]
 '''
-        (migration_dir / "0001_create_users.py").write_text(migration1_content)
+    (migration_dir / "0001_create_users.py").write_text(migration1_content)

-        migration2_content = f'''"""Create posts table."""
+    migration2_content = f'''"""Create posts table."""


 def up():
@@ -163,88 +160,87 @@ def down():
     """Drop posts table."""
     return ["DROP TABLE IF EXISTS {posts_table}"]
 '''
-        (migration_dir / "0002_create_posts.py").write_text(migration2_content)
-
-        try:
-            await commands.upgrade()
-
-            async with config.provide_session() as driver:
-                users_result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
-                    (mysql_service.db,),
-                )
-                posts_result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'",
-                    (mysql_service.db,),
-                )
-                assert len(users_result.data) == 1
-                assert len(posts_result.data) == 1
-
-                await driver.execute(
-                    f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com")
-                )
-                await driver.execute(
-                    f"INSERT INTO {posts_table} (title, content, user_id) VALUES (%s, %s, %s)",
-                    ("Test Post", "This is a test post", 1),
-                )
-
-            await commands.downgrade("0001")
-
-            async with config.provide_session() as driver:
-                users_result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
-                    (mysql_service.db,),
-                )
-                posts_result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'",
-                    (mysql_service.db,),
-                )
-                assert len(users_result.data) == 1
-                assert len(posts_result.data) == 0
-
-            await commands.downgrade("base")
-
-            async with config.provide_session() as driver:
-                users_result = await driver.execute(
-                    f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name IN ('{users_table}', '{posts_table}')",
-                    (mysql_service.db,),
-                )
-                assert len(users_result.data) == 0
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
-
-
-async def test_asyncmy_migration_current_command(mysql_service: MySQLService) -> None:
+    (migration_dir / "0002_create_posts.py").write_text(migration2_content)
+
+    try:
+        await commands.upgrade()
+
+        async with config.provide_session() as driver:
+            users_result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
+                (mysql_service.db,),
+            )
+            posts_result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'",
+                (mysql_service.db,),
+            )
+            assert len(users_result.data) == 1
+            assert len(posts_result.data) == 1
+
+            await driver.execute(
+                f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com")
+            )
+            await driver.execute(
+                f"INSERT INTO {posts_table} (title, content, user_id) VALUES (%s, %s, %s)",
+                ("Test Post", "This is a test post", 1),
+            )
+
+        await commands.downgrade("0001")
+
+        async with config.provide_session() as driver:
+            users_result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'",
+                (mysql_service.db,),
+            )
+            posts_result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'",
+                (mysql_service.db,),
+            )
+            assert len(users_result.data) == 1
+            assert len(posts_result.data) == 0
+
+        await commands.downgrade("base")
+
+        async with config.provide_session() as driver:
+            users_result = await driver.execute(
+                f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name IN ('{users_table}', '{posts_table}')",
+                (mysql_service.db,),
+            )
+            assert len(users_result.data) == 0
+    finally:
+        if config.pool_instance:
+            await config.close_pool()
+
+
+async def test_asyncmy_migration_current_command(tmp_path: Path, mysql_service: MySQLService) -> None:
     """Test the current migration command shows correct version for Asyncmy."""
     test_id = "asyncmy_current_cmd"
     migration_table = f"sqlspec_migrations_{test_id}"
     users_table = f"users_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = AsyncmyConfig(
-            pool_config={
-                "host": mysql_service.host,
-                "port": mysql_service.port,
-                "user": mysql_service.user,
-                "password": mysql_service.password,
-                "database": mysql_service.db,
-                "autocommit": True,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
-
-        try:
-            await commands.init(str(migration_dir), package=True)
+    migration_dir = tmp_path / "migrations"
+
+    config = AsyncmyConfig(
+        pool_config={
+            "host": mysql_service.host,
+            "port": mysql_service.port,
+            "user": mysql_service.user,
+            "password": mysql_service.password,
+            "database": mysql_service.db,
+            "autocommit": True,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)
+
+    try:
+        await commands.init(str(migration_dir), package=True)

-            current_version = await commands.current()
-            assert current_version is None or current_version == "base"
+        current_version = await commands.current()
+        assert current_version is None or current_version == "base"

-            migration_content = f'''"""Initial schema migration."""
+        migration_content = f'''"""Initial schema migration."""


 def up():
@@ -261,48 +257,47 @@ def down():
     """Drop users table."""
     return ["DROP TABLE IF EXISTS {users_table}"]
 '''
-            (migration_dir / "0001_create_users.py").write_text(migration_content)
+        (migration_dir / "0001_create_users.py").write_text(migration_content)

-            await commands.upgrade()
+        await commands.upgrade()

-            current_version = await commands.current()
-            assert current_version == "0001"
+        current_version = await commands.current()
+        assert current_version == "0001"

-            await commands.downgrade("base")
+        await commands.downgrade("base")

-            current_version = await commands.current()
-            assert current_version is None or current_version == "base"
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        current_version = await commands.current()
+        assert current_version is None or current_version == "base"
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_asyncmy_migration_error_handling(mysql_service: MySQLService) -> None:
+async def test_asyncmy_migration_error_handling(tmp_path: Path, mysql_service: MySQLService) -> None:
     """Test Asyncmy migration error handling."""
     test_id = "asyncmy_error_handling"
     migration_table = f"sqlspec_migrations_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = AsyncmyConfig(
-            pool_config={
-                "host": mysql_service.host,
-                "port": mysql_service.port,
-                "user": mysql_service.user,
-                "password": mysql_service.password,
-                "database": mysql_service.db,
-                "autocommit": True,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
-
-        try:
-            await commands.init(str(migration_dir), package=True)
+    migration_dir = tmp_path / "migrations"
+
+    config = AsyncmyConfig(
+        pool_config={
+            "host": mysql_service.host,
+            "port": mysql_service.port,
+            "user": mysql_service.user,
+            "password": mysql_service.password,
+            "database": mysql_service.db,
+            "autocommit": True,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)
+
+    try:
+        await commands.init(str(migration_dir), package=True)

-            migration_content = '''"""Migration with invalid SQL."""
+        migration_content = '''"""Migration with invalid SQL."""


 def up():
@@ -314,45 +309,44 @@ def down():
     """Drop table."""
     return ["DROP TABLE IF EXISTS invalid_table"]
 '''
-            (migration_dir / "0001_invalid.py").write_text(migration_content)
+        (migration_dir / "0001_invalid.py").write_text(migration_content)

-            await commands.upgrade()
+        await commands.upgrade()

-            async with config.provide_session() as driver:
-                count = await driver.select_value(f"SELECT COUNT(*) FROM {migration_table}")
-                assert count == 0, f"Expected empty migration table after failed migration, but found {count} records"
-        finally:
-            if config.pool_instance:
-                await config.close_pool()
+        async with config.provide_session() as driver:
+            count = await driver.select_value(f"SELECT COUNT(*) FROM {migration_table}")
+            assert count == 0, f"Expected empty migration table after failed migration, but found {count} records"
+    finally:
+        if config.pool_instance:
+            await config.close_pool()


-async def test_asyncmy_migration_with_transactions(mysql_service: MySQLService) -> None:
+async def test_asyncmy_migration_with_transactions(tmp_path: Path, mysql_service: MySQLService) -> None:
     """Test Asyncmy migrations work properly with transactions."""
     test_id = "asyncmy_transactions"
     migration_table = f"sqlspec_migrations_{test_id}"
     users_table = f"users_{test_id}"

-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = AsyncmyConfig(
-            pool_config={
-                "host": mysql_service.host,
-                "port": mysql_service.port,
-                "user": mysql_service.user,
-                "password": mysql_service.password,
-                "database": mysql_service.db,
-                "autocommit": False,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
-
-        try:
-            await commands.init(str(migration_dir), package=True)
+    migration_dir = tmp_path / "migrations"
/ "migrations" + + config = AsyncmyConfig( + pool_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "database": mysql_service.db, + "autocommit": False, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -370,42 +364,42 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) - - await commands.upgrade() + (migration_dir / "0001_create_users.py").write_text(migration_content) - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Transaction User", "trans@example.com"), - ) + await commands.upgrade() - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - await driver.commit() - except Exception: - await driver.rollback() - raise + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Transaction User", "trans@example.com"), + ) result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 + + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Rollback User", "rollback@example.com"), - ) - - raise Exception("Intentional rollback") - except Exception: - await driver.rollback() - - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_asyncpg/test_migrations.py b/tests/integration/test_adapters/test_asyncpg/test_migrations.py index d57ac45d..30151304 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_migrations.py +++ b/tests/integration/test_adapters/test_asyncpg/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for AsyncPG (PostgreSQL) migration workflow.""" -import tempfile from pathlib import Path import pytest @@ -12,32 +11,28 @@ pytestmark = pytest.mark.xdist_group("postgres") -async def test_asyncpg_migration_full_workflow(postgres_service: PostgresService) -> None: +async def test_asyncpg_migration_full_workflow(tmp_path: Path, postgres_service: 
PostgresService) -> None: """Test full AsyncPG migration workflow: init -> create -> upgrade -> downgrade.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg", - }, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" - await commands.init(str(migration_dir), package=True) + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_asyncpg"}, + ) + commands = AsyncMigrationCommands(config) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Initial schema migration.""" + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() + + migration_content = '''"""Initial schema migration.""" def up(): @@ -57,62 +52,56 @@ def down(): return ["DROP TABLE IF EXISTS users"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - await commands.upgrade() + try: + await commands.upgrade() - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - ) - assert len(result.data) == 1 + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" + ) + assert len(result.data) == 1 - await driver.execute( - "INSERT INTO users (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") - ) + await driver.execute("INSERT INTO users (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com")) - users_result = await driver.execute("SELECT * FROM users") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = await driver.execute("SELECT * FROM users") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - await commands.downgrade("base") + await commands.downgrade("base") - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def 
test_asyncpg_multiple_migrations_workflow(postgres_service: PostgresService) -> None: +async def test_asyncpg_multiple_migrations_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncPG workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg", - }, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" - await commands.init(str(migration_dir), package=True) + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_asyncpg"}, + ) + commands = AsyncMigrationCommands(config) - migration1_content = '''"""Create users table.""" + await commands.init(str(migration_dir), package=True) + + migration1_content = '''"""Create users table.""" def up(): @@ -131,9 +120,9 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS users"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) - migration2_content = '''"""Create posts table.""" + migration2_content = '''"""Create posts table.""" def up(): @@ -153,80 +142,74 @@ def down(): """Drop posts table.""" return ["DROP TABLE IF EXISTS posts"] ''' - (migration_dir / "0002_create_posts.py").write_text(migration2_content) - - try: - await commands.upgrade() - - async with config.provide_session() as driver: - users_result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - ) - posts_result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'posts'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 1 - - await driver.execute( - "INSERT INTO users (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") - ) - await driver.execute( - "INSERT INTO posts (title, content, user_id) VALUES ($1, $2, $3)", - ("Test Post", "This is a test post", 1), - ) - - await commands.downgrade("0001") - - async with config.provide_session() as driver: - users_result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - ) - posts_result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'posts'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 0 - - await commands.downgrade("base") - - async with config.provide_session() as driver: - users_result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('users', 'posts')" - ) - assert len(users_result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() - - -async def 
test_asyncpg_migration_current_command(postgres_service: PostgresService) -> None: + (migration_dir / "0002_create_posts.py").write_text(migration2_content) + + try: + await commands.upgrade() + + async with config.provide_session() as driver: + users_result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" + ) + posts_result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'posts'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + await driver.execute("INSERT INTO users (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com")) + await driver.execute( + "INSERT INTO posts (title, content, user_id) VALUES ($1, $2, $3)", + ("Test Post", "This is a test post", 1), + ) + + await commands.downgrade("0001") + + async with config.provide_session() as driver: + users_result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" + ) + posts_result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'posts'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 + + await commands.downgrade("base") + + async with config.provide_session() as driver: + users_result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('users', 'posts')" + ) + assert len(users_result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_migration_current_command(tmp_path: Path, postgres_service: PostgresService) -> None: """Test the current migration command shows correct version for AsyncPG.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg", - }, - ) - commands = AsyncMigrationCommands(config) - - try: - await commands.init(str(migration_dir), package=True) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_asyncpg"}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - current_version = await commands.current() - assert current_version is None or current_version == "base" + current_version = await commands.current() + assert current_version is None or current_version == "base" - migration_content = '''"""Initial schema migration.""" + migration_content = '''"""Initial schema migration.""" def up(): @@ -243,46 +226,42 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS users"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + 
await commands.upgrade() - current_version = await commands.current() - assert current_version == "0001" + current_version = await commands.current() + assert current_version == "0001" - await commands.downgrade("base") + await commands.downgrade("base") - current_version = await commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - await config.close_pool() + current_version = await commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_migration_error_handling(postgres_service: PostgresService) -> None: +async def test_asyncpg_migration_error_handling(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncPG migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg", - }, - ) - commands = AsyncMigrationCommands(config) - - try: - await commands.init(str(migration_dir), package=True) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_asyncpg"}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with invalid SQL.""" + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -294,46 +273,42 @@ def down(): """Drop table.""" return ["DROP TABLE IF EXISTS invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - try: - await driver.execute("SELECT version FROM sqlspec_migrations_asyncpg ORDER BY version") - msg = "Expected migration table to not exist, but it does" - raise AssertionError(msg) - except Exception as e: - assert "no such" in str(e).lower() or "does not exist" in str(e).lower() - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + try: + await driver.execute("SELECT version FROM sqlspec_migrations_asyncpg ORDER BY version") + msg = "Expected migration table to not exist, but it does" + raise AssertionError(msg) + except Exception as e: + assert "no such" in str(e).lower() or "does not exist" in str(e).lower() + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_migration_with_transactions(postgres_service: PostgresService) -> None: +async def test_asyncpg_migration_with_transactions(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncPG migrations work properly with transactions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": 
postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg", - }, - ) - commands = AsyncMigrationCommands(config) - - try: - await commands.init(str(migration_dir), package=True) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_asyncpg"}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Initial schema migration.""" + migration_content = '''"""Initial schema migration.""" def up(): @@ -351,68 +326,67 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS users"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - "INSERT INTO users (name, email) VALUES ($1, $2)", ("Transaction User", "trans@example.com") - ) - - result = await driver.execute("SELECT * FROM users WHERE name = 'Transaction User'") - assert len(result.data) == 1 - await driver.commit() - except Exception: - await driver.rollback() - raise + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + "INSERT INTO users (name, email) VALUES ($1, $2)", ("Transaction User", "trans@example.com") + ) result = await driver.execute("SELECT * FROM users WHERE name = 'Transaction User'") assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise + + result = await driver.execute("SELECT * FROM users WHERE name = 'Transaction User'") + assert len(result.data) == 1 - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - "INSERT INTO users (name, email) VALUES ($1, $2)", ("Rollback User", "rollback@example.com") - ) + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + "INSERT INTO users (name, email) VALUES ($1, $2)", ("Rollback User", "rollback@example.com") + ) - raise Exception("Intentional rollback") - except Exception: - await driver.rollback() + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() - result = await driver.execute("SELECT * FROM users WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + result = await driver.execute("SELECT * FROM users WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_migrate_up_method(postgres_service: PostgresService) -> None: +async def test_asyncpg_config_migrate_up_method(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.migrate_up() method works correctly.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - 
config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg_config", - }, - ) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_asyncpg_config", + }, + ) - try: - await config.init_migrations() + try: + await config.init_migrations() - migration_content = '''"""Create products table.""" + migration_content = '''"""Create products table.""" def up(): @@ -431,43 +405,42 @@ def down(): return ["DROP TABLE IF EXISTS products"] ''' - (migration_dir / "0001_create_products.py").write_text(migration_content) + (migration_dir / "0001_create_products.py").write_text(migration_content) - await config.migrate_up() + await config.migrate_up() - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'products'" - ) - assert len(result.data) == 1 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'products'" + ) + assert len(result.data) == 1 + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_migrate_down_method(postgres_service: PostgresService) -> None: +async def test_asyncpg_config_migrate_down_method(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.migrate_down() method works correctly.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_asyncpg_down", - }, - ) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_asyncpg_down", + }, + ) - try: - await config.init_migrations() + try: + await config.init_migrations() - migration_content = '''"""Create inventory table.""" + migration_content = '''"""Create inventory table.""" def up(): @@ -485,54 +458,50 @@ def down(): return ["DROP TABLE IF EXISTS inventory"] ''' - (migration_dir / "0001_create_inventory.py").write_text(migration_content) + (migration_dir / "0001_create_inventory.py").write_text(migration_content) - await config.migrate_up() + await config.migrate_up() - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT 
table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" - ) - assert len(result.data) == 1 + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" + ) + assert len(result.data) == 1 - await config.migrate_down() + await config.migrate_down() - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_get_current_migration_method(postgres_service: PostgresService) -> None: +async def test_asyncpg_config_get_current_migration_method(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.get_current_migration() method returns correct version.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_current", - }, - ) + migration_dir = tmp_path / "migrations" - try: - await config.init_migrations() + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_current"}, + ) - current_version = await config.get_current_migration() - assert current_version is None or current_version == "base" + try: + await config.init_migrations() - migration_content = '''"""First migration.""" + current_version = await config.get_current_migration() + assert current_version is None or current_version == "base" + + migration_content = '''"""First migration.""" def up(): @@ -545,68 +514,66 @@ def down(): return ["DROP TABLE IF EXISTS test_version"] ''' - (migration_dir / "0001_first.py").write_text(migration_content) + (migration_dir / "0001_first.py").write_text(migration_content) - await config.migrate_up() + await config.migrate_up() - current_version = await config.get_current_migration() - assert current_version == "0001" - finally: - if config.pool_instance: - await config.close_pool() + current_version = await config.get_current_migration() + assert current_version == "0001" + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_create_migration_method(postgres_service: PostgresService) -> None: +async def test_asyncpg_config_create_migration_method(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.create_migration() method generates migration file.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / 
"migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_create"}, - ) + migration_dir = tmp_path / "migrations" - try: - await config.init_migrations() + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_create"}, + ) - await config.create_migration("add users table", file_type="py") + try: + await config.init_migrations() - migration_files = list(migration_dir.glob("*.py")) - migration_files = [f for f in migration_files if f.name != "__init__.py"] + await config.create_migration("add users table", file_type="py") - assert len(migration_files) == 1 - assert "add_users_table" in migration_files[0].name - finally: - if config.pool_instance: - await config.close_pool() + migration_files = list(migration_dir.glob("*.py")) + migration_files = [f for f in migration_files if f.name != "__init__.py"] + assert len(migration_files) == 1 + assert "add_users_table" in migration_files[0].name + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_stamp_migration_method(postgres_service: PostgresService) -> None: + +async def test_asyncpg_config_stamp_migration_method(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.stamp_migration() method marks database at revision.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_stamp"}, - ) + migration_dir = tmp_path / "migrations" - try: - await config.init_migrations() + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_stamp"}, + ) - migration_content = '''"""Stamped migration.""" + try: + await config.init_migrations() + + migration_content = '''"""Stamped migration.""" def up(): @@ -619,43 +586,42 @@ def down(): return ["DROP TABLE IF EXISTS stamped"] ''' - (migration_dir / "0001_stamped.py").write_text(migration_content) + (migration_dir / "0001_stamped.py").write_text(migration_content) - await config.stamp_migration("0001") + await config.stamp_migration("0001") - current_version = await config.get_current_migration() - assert current_version == "0001" + current_version = await config.get_current_migration() + assert current_version == "0001" - async with config.provide_session() as driver: - result = await driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'stamped'" - ) - assert len(result.data) == 0 - finally: - if 
config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'stamped'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def test_asyncpg_config_fix_migrations_dry_run(postgres_service: PostgresService) -> None: +async def test_asyncpg_config_fix_migrations_dry_run(tmp_path: Path, postgres_service: PostgresService) -> None: """Test AsyncpgConfig.fix_migrations() dry run shows what would change.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = AsyncpgConfig( - pool_config={ - "host": postgres_service.host, - "port": postgres_service.port, - "user": postgres_service.user, - "password": postgres_service.password, - "database": postgres_service.database, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_fix"}, - ) + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_fix"}, + ) - try: - await config.init_migrations() + try: + await config.init_migrations() - timestamp_migration = '''"""Timestamp migration.""" + timestamp_migration = '''"""Timestamp migration.""" def up(): @@ -668,15 +634,15 @@ def down(): return ["DROP TABLE IF EXISTS timestamp_test"] ''' - (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) + (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) - await config.fix_migrations(dry_run=True, yes=True) + await config.fix_migrations(dry_run=True, yes=True) - timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" - assert timestamp_file.exists() + timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" + assert timestamp_file.exists() - sequential_file = migration_dir / "0001_timestamp_migration.py" - assert not sequential_file.exists() - finally: - if config.pool_instance: - await config.close_pool() + sequential_file = migration_dir / "0001_timestamp_migration.py" + assert not sequential_file.exists() + finally: + if config.pool_instance: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_duckdb/test_migrations.py b/tests/integration/test_adapters/test_duckdb/test_migrations.py index 5b9b350f..d381816c 100644 --- a/tests/integration/test_adapters/test_duckdb/test_migrations.py +++ b/tests/integration/test_adapters/test_duckdb/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for DuckDB migration workflow.""" -import tempfile from pathlib import Path from typing import Any @@ -12,24 +11,23 @@ pytestmark = pytest.mark.xdist_group("duckdb") -def test_duckdb_migration_full_workflow() -> None: +def test_duckdb_migration_full_workflow(tmp_path: Path) -> None: """Test full DuckDB migration workflow: init -> create -> upgrade -> downgrade.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - db_path = Path(temp_dir) / "test.duckdb" + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" - config = DuckDBConfig( - 
pool_config={"database": str(db_path)}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = DuckDBConfig( + pool_config={"database": str(db_path)}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() - migration_content = '''"""Initial schema migration.""" + migration_content = '''"""Initial schema migration.""" def up(): @@ -49,44 +47,43 @@ def down(): return ["DROP TABLE IF EXISTS users"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'") - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'") + assert len(result.data) == 1 - driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "John Doe", "john@example.com")) + driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "John Doe", "john@example.com")) - users_result = driver.execute("SELECT * FROM users") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = driver.execute("SELECT * FROM users") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'users'") + assert len(result.data) == 0 -def test_duckdb_multiple_migrations_workflow() -> None: +def test_duckdb_multiple_migrations_workflow(tmp_path: Path) -> None: """Test DuckDB workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - db_path = Path(temp_dir) / "test.duckdb" + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" - config = DuckDBConfig( - pool_config={"database": str(db_path)}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = DuckDBConfig( + pool_config={"database": str(db_path)}, + migration_config={"script_location": str(migration_dir), "version_table_name": 
"sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration1_content = '''"""Create users table.""" + migration1_content = '''"""Create users table.""" def up(): @@ -105,7 +102,7 @@ def down(): return ["DROP TABLE IF EXISTS users"] ''' - migration2_content = '''"""Create posts table.""" + migration2_content = '''"""Create posts table.""" def up(): @@ -126,66 +123,63 @@ def down(): return ["DROP TABLE IF EXISTS posts"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) - (migration_dir / "0002_create_posts.py").write_text(migration2_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0002_create_posts.py").write_text(migration2_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - tables_result = driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name" - ) - table_names = [t["table_name"] for t in tables_result.data] - assert "users" in table_names - assert "posts" in table_names + with config.provide_session() as driver: + tables_result = driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name" + ) + table_names = [t["table_name"] for t in tables_result.data] + assert "users" in table_names + assert "posts" in table_names - driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "Author", "author@example.com")) - driver.execute( - "INSERT INTO posts (id, title, content, user_id) VALUES (?, ?, ?, ?)", (1, "My Post", "Post content", 1) - ) + driver.execute("INSERT INTO users (id, name, email) VALUES (?, ?, ?)", (1, "Author", "author@example.com")) + driver.execute( + "INSERT INTO posts (id, title, content, user_id) VALUES (?, ?, ?, ?)", (1, "My Post", "Post content", 1) + ) - posts_result = driver.execute("SELECT * FROM posts") - assert len(posts_result.data) == 1 - assert posts_result.data[0]["title"] == "My Post" + posts_result = driver.execute("SELECT * FROM posts") + assert len(posts_result.data) == 1 + assert posts_result.data[0]["title"] == "My Post" - commands.downgrade("0001") + commands.downgrade("0001") - with config.provide_session() as driver: - tables_result = driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'" - ) - table_names = [t["table_name"] for t in tables_result.data] - assert "users" in table_names - assert "posts" not in table_names + with config.provide_session() as driver: + tables_result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'") + table_names = [t["table_name"] for t in tables_result.data] + assert "users" in table_names + assert "posts" not in table_names - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - tables_result = driver.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' AND table_name NOT LIKE 'sqlspec_%'" - ) + with config.provide_session() as driver: + tables_result = driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' AND table_name NOT LIKE 'sqlspec_%'" + ) - table_names = [t["table_name"] for t in tables_result.data if not t["table_name"].startswith("sqlspec_")] - assert len(table_names) 
== 0 + table_names = [t["table_name"] for t in tables_result.data if not t["table_name"].startswith("sqlspec_")] + assert len(table_names) == 0 -def test_duckdb_migration_current_command() -> None: +def test_duckdb_migration_current_command(tmp_path: Path) -> None: """Test the current migration command shows correct version for DuckDB.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - db_path = Path(temp_dir) / "test.duckdb" + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" - config = DuckDBConfig( - pool_config={"database": str(db_path)}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = DuckDBConfig( + pool_config={"database": str(db_path)}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - commands.current(verbose=False) + commands.current(verbose=False) - migration_content = '''"""Test migration.""" + migration_content = '''"""Test migration.""" def up(): @@ -198,28 +192,27 @@ def down(): return ["DROP TABLE IF EXISTS test_table"] ''' - (migration_dir / "0001_test.py").write_text(migration_content) + (migration_dir / "0001_test.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - commands.current(verbose=True) + commands.current(verbose=True) -def test_duckdb_migration_error_handling() -> None: +def test_duckdb_migration_error_handling(tmp_path: Path) -> None: """Test DuckDB migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - db_path = Path(temp_dir) / "test.duckdb" + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" - config = DuckDBConfig( - pool_config={"database": str(db_path)}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = DuckDBConfig( + pool_config={"database": str(db_path)}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration_content = '''"""Bad migration.""" + migration_content = '''"""Bad migration.""" def up(): @@ -232,30 +225,29 @@ def down(): return [] ''' - (migration_dir / "0001_bad.py").write_text(migration_content) + (migration_dir / "0001_bad.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - count = driver.select_value("SELECT COUNT(*) FROM sqlspec_migrations") - assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" + with config.provide_session() as driver: + count = driver.select_value("SELECT COUNT(*) FROM sqlspec_migrations") + assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" -def 
test_duckdb_migration_with_transactions() -> None: +def test_duckdb_migration_with_transactions(tmp_path: Path) -> None: """Test DuckDB migrations work properly with transactions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - db_path = Path(temp_dir) / "test.duckdb" + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" - config = DuckDBConfig( - pool_config={"database": str(db_path)}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = DuckDBConfig( + pool_config={"database": str(db_path)}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with multiple operations.""" + migration_content = '''"""Migration with multiple operations.""" def up(): @@ -275,18 +267,18 @@ def down(): return ["DROP TABLE IF EXISTS customers"] ''' - (migration_dir / "0001_transaction_test.py").write_text(migration_content) + (migration_dir / "0001_transaction_test.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - customers_result = driver.execute("SELECT * FROM customers ORDER BY name") - assert len(customers_result.data) == 2 - assert customers_result.data[0]["name"] == "Customer 1" - assert customers_result.data[1]["name"] == "Customer 2" + with config.provide_session() as driver: + customers_result = driver.execute("SELECT * FROM customers ORDER BY name") + assert len(customers_result.data) == 2 + assert customers_result.data[0]["name"] == "Customer 1" + assert customers_result.data[1]["name"] == "Customer 2" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'customers'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'customers'") + assert len(result.data) == 0 diff --git a/tests/integration/test_adapters/test_oracledb/test_migrations.py b/tests/integration/test_adapters/test_oracledb/test_migrations.py index 1bc964ba..185900c9 100644 --- a/tests/integration/test_adapters/test_oracledb/test_migrations.py +++ b/tests/integration/test_adapters/test_oracledb/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for OracleDB migration workflow.""" -import tempfile from pathlib import Path from typing import Any @@ -13,34 +12,33 @@ pytestmark = pytest.mark.xdist_group("oracle") -def test_oracledb_sync_migration_full_workflow(oracle_23ai_service: OracleService) -> None: +def test_oracledb_sync_migration_full_workflow(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test full OracleDB sync migration workflow: init -> create -> upgrade -> downgrade.""" test_id = "oracledb_sync_full" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleSyncConfig( - pool_config={ - "host": 
oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" - commands.init(str(migration_dir), package=True) + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() + + migration_content = f'''"""Initial schema migration.""" def up(): @@ -60,69 +58,62 @@ def down(): return ["DROP TABLE {users_table}"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - commands.upgrade() + try: + commands.upgrade() - with config.provide_session() as driver: - result = driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute(f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'") + assert len(result.data) == 1 - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com") - ) + driver.execute(f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com")) - users_result = driver.execute(f"SELECT * FROM {users_table}") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - config.close_pool() + with config.provide_session() as driver: + result = driver.execute(f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + config.close_pool() -async def test_oracledb_async_migration_full_workflow(oracle_23ai_service: OracleService) -> None: +async def test_oracledb_async_migration_full_workflow(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test full OracleDB async migration workflow: init -> create -> upgrade -> downgrade.""" test_id = "oracledb_async_full" 
migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleAsyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - "min": 1, - "max": 5, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" - await commands.init(str(migration_dir), package=True) + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + await commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() + + migration_content = f'''"""Initial schema migration.""" def up(): @@ -142,40 +133,40 @@ def down(): return ["DROP TABLE {users_table}"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - await commands.upgrade() + try: + await commands.upgrade() - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - assert len(result.data) == 1 + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" + ) + assert len(result.data) == 1 - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com") - ) + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com") + ) - users_result = await driver.execute(f"SELECT * FROM {users_table}") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = await driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - await commands.downgrade("base") + await commands.downgrade("base") - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -def 
test_oracledb_sync_multiple_migrations_workflow(oracle_23ai_service: OracleService) -> None:
+def test_oracledb_sync_multiple_migrations_workflow(tmp_path: Path, oracle_23ai_service: OracleService) -> None:
     """Test OracleDB sync workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
 
     test_id = "oracledb_sync_multiple"
@@ -183,24 +174,23 @@ def test_oracledb_sync_multiple_migrations_workflow(oracle_23ai_service: OracleS
     users_table = f"users_{test_id}"
     posts_table = f"posts_{test_id}"
 
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = OracleSyncConfig(
-            pool_config={
-                "host": oracle_23ai_service.host,
-                "port": oracle_23ai_service.port,
-                "service_name": oracle_23ai_service.service_name,
-                "user": oracle_23ai_service.user,
-                "password": oracle_23ai_service.password,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
+    migration_dir = tmp_path / "migrations"
 
-        commands.init(str(migration_dir), package=True)
+    config = OracleSyncConfig(
+        pool_config={
+            "host": oracle_23ai_service.host,
+            "port": oracle_23ai_service.port,
+            "service_name": oracle_23ai_service.service_name,
+            "user": oracle_23ai_service.user,
+            "password": oracle_23ai_service.password,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config)
 
-        migration1_content = f'''"""Create users table."""
+    commands.init(str(migration_dir), package=True)
+
+    migration1_content = f'''"""Create users table."""
 
 
 def up():
@@ -219,9 +209,9 @@ def down():
     """Drop users table."""
     return ["DROP TABLE {users_table}"]
 '''
-        (migration_dir / "0001_create_users.py").write_text(migration1_content)
+    (migration_dir / "0001_create_users.py").write_text(migration1_content)
 
-        migration2_content = f'''"""Create posts table."""
+    migration2_content = f'''"""Create posts table."""
 
 
 def up():
@@ -242,54 +232,52 @@ def down():
     """Drop posts table."""
     return ["DROP TABLE {posts_table}"]
 '''
-        (migration_dir / "0002_create_posts.py").write_text(migration2_content)
-
-        try:
-            commands.upgrade()
+    (migration_dir / "0002_create_posts.py").write_text(migration2_content)
 
-            with config.provide_session() as driver:
-                users_result = driver.execute(
-                    f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'"
-                )
-                posts_result = driver.execute(
-                    f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'"
-                )
-                assert len(users_result.data) == 1
-                assert len(posts_result.data) == 1
-
-                driver.execute(
-                    f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com")
-                )
-                driver.execute(
-                    f"INSERT INTO {posts_table} (title, content, user_id) VALUES (:1, :2, :3)",
-                    ("Test Post", "This is a test post", 1),
-                )
+    try:
+        commands.upgrade()
 
-            commands.downgrade("0001")
+        with config.provide_session() as driver:
+            users_result = driver.execute(
+                f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'"
+            )
+            posts_result = driver.execute(
+                f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'"
+            )
+            assert len(users_result.data) == 1
+            assert len(posts_result.data) == 1
+
+            driver.execute(f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com"))
+            driver.execute(
+                f"INSERT INTO {posts_table} (title, content, user_id) VALUES (:1, :2, :3)",
+                ("Test Post", "This is a test post", 1),
+            )
+
+        commands.downgrade("0001")
 
-            with config.provide_session() as driver:
-                users_result = driver.execute(
-                    f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'"
-                )
-                posts_result = driver.execute(
-                    f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'"
-                )
-                assert len(users_result.data) == 1
-                assert len(posts_result.data) == 0
+        with config.provide_session() as driver:
+            users_result = driver.execute(
+                f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'"
+            )
+            posts_result = driver.execute(
+                f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'"
+            )
+            assert len(users_result.data) == 1
+            assert len(posts_result.data) == 0
 
-            commands.downgrade("base")
+        commands.downgrade("base")
 
-            with config.provide_session() as driver:
-                users_result = driver.execute(
-                    f"SELECT table_name FROM user_tables WHERE table_name IN ('{users_table.upper()}', '{posts_table.upper()}')"
-                )
-                assert len(users_result.data) == 0
-        finally:
-            if config.pool_instance:
-                config.close_pool()
+        with config.provide_session() as driver:
+            users_result = driver.execute(
+                f"SELECT table_name FROM user_tables WHERE table_name IN ('{users_table.upper()}', '{posts_table.upper()}')"
+            )
+            assert len(users_result.data) == 0
+    finally:
+        if config.pool_instance:
+            config.close_pool()
 
 
-async def test_oracledb_async_multiple_migrations_workflow(oracle_23ai_service: OracleService) -> None:
+async def test_oracledb_async_multiple_migrations_workflow(tmp_path: Path, oracle_23ai_service: OracleService) -> None:
     """Test OracleDB async workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all."""
 
    test_id = "oracledb_async_multiple"
@@ -297,26 +285,25 @@ async def test_oracledb_async_multiple_migrations_workflow(oracle_23ai_service:
     users_table = f"users_{test_id}"
     posts_table = f"posts_{test_id}"
 
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-
-        config = OracleAsyncConfig(
-            pool_config={
-                "host": oracle_23ai_service.host,
-                "port": oracle_23ai_service.port,
-                "service_name": oracle_23ai_service.service_name,
-                "user": oracle_23ai_service.user,
-                "password": oracle_23ai_service.password,
-                "min": 1,
-                "max": 5,
-            },
-            migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
-        )
-        commands = AsyncMigrationCommands(config)
+    migration_dir = tmp_path / "migrations"
 
-        await commands.init(str(migration_dir), package=True)
+    config = OracleAsyncConfig(
+        pool_config={
+            "host": oracle_23ai_service.host,
+            "port": oracle_23ai_service.port,
+            "service_name": oracle_23ai_service.service_name,
+            "user": oracle_23ai_service.user,
+            "password": oracle_23ai_service.password,
+            "min": 1,
+            "max": 5,
+        },
+        migration_config={"script_location": str(migration_dir), "version_table_name": migration_table},
+    )
+    commands = AsyncMigrationCommands(config)
 
-        migration1_content = f'''"""Create users table."""
+    await commands.init(str(migration_dir), package=True)
+
+    migration1_content = f'''"""Create users table."""
 
 
 def up():
@@ -335,9 +322,9 @@ def down():
     """Drop users table."""
     return ["DROP TABLE {users_table}"]
 '''
-        (migration_dir / 
"0001_create_users.py").write_text(migration1_content) - migration2_content = f'''"""Create posts table.""" + migration2_content = f'''"""Create posts table.""" def up(): @@ -358,82 +345,81 @@ def down(): """Drop posts table.""" return ["DROP TABLE {posts_table}"] ''' - (migration_dir / "0002_create_posts.py").write_text(migration2_content) - - try: - await commands.upgrade() + (migration_dir / "0002_create_posts.py").write_text(migration2_content) - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - posts_result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 1 - - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com") - ) - await driver.execute( - f"INSERT INTO {posts_table} (title, content, user_id) VALUES (:1, :2, :3)", - ("Test Post", "This is a test post", 1), - ) + try: + await commands.upgrade() - await commands.downgrade("0001") + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" + ) + posts_result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", ("John Doe", "john@example.com") + ) + await driver.execute( + f"INSERT INTO {posts_table} (title, content, user_id) VALUES (:1, :2, :3)", + ("Test Post", "This is a test post", 1), + ) + + await commands.downgrade("0001") - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" - ) - posts_result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 0 + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{users_table.upper()}'" + ) + posts_result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name = '{posts_table.upper()}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 - await commands.downgrade("base") + await commands.downgrade("base") - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM user_tables WHERE table_name IN ('{users_table.upper()}', '{posts_table.upper()}')" - ) - assert len(users_result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM user_tables WHERE table_name IN ('{users_table.upper()}', '{posts_table.upper()}')" + ) + assert len(users_result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -def test_oracledb_sync_migration_current_command(oracle_23ai_service: OracleService) -> None: +def test_oracledb_sync_migration_current_command(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test the current migration command shows correct version 
for OracleDB sync.""" test_id = "oracledb_sync_current" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleSyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" + + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - try: - commands.init(str(migration_dir), package=True) + try: + commands.init(str(migration_dir), package=True) - current_version = commands.current() - assert current_version is None or current_version == "base" + current_version = commands.current() + assert current_version is None or current_version == "base" - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -450,53 +436,52 @@ def down(): """Drop users table.""" return ["DROP TABLE {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - current_version = commands.current() - assert current_version == "0001" + current_version = commands.current() + assert current_version == "0001" - commands.downgrade("base") + commands.downgrade("base") - current_version = commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - config.close_pool() + current_version = commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + config.close_pool() -async def test_oracledb_async_migration_current_command(oracle_23ai_service: OracleService) -> None: +async def test_oracledb_async_migration_current_command(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test the current migration command shows correct version for OracleDB async.""" test_id = "oracledb_async_current" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleAsyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - "min": 1, - "max": 5, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" + + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": 
oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - current_version = await commands.current() - assert current_version is None or current_version == "base" + current_version = await commands.current() + assert current_version is None or current_version == "base" - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -513,47 +498,46 @@ def down(): """Drop users table.""" return ["DROP TABLE {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - current_version = await commands.current() - assert current_version == "0001" + current_version = await commands.current() + assert current_version == "0001" - await commands.downgrade("base") + await commands.downgrade("base") - current_version = await commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - await config.close_pool() + current_version = await commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + await config.close_pool() -def test_oracledb_sync_migration_error_handling(oracle_23ai_service: OracleService) -> None: +def test_oracledb_sync_migration_error_handling(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test OracleDB sync migration error handling.""" test_id = "oracledb_sync_error" migration_table = f"sqlspec_migrations_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleSyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" - try: - commands.init(str(migration_dir), package=True) + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - migration_content = '''"""Migration with invalid SQL.""" + try: + commands.init(str(migration_dir), package=True) + + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -565,45 +549,44 @@ def down(): """Drop table.""" return ["DROP TABLE invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - 
commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - count = driver.select_value(f"SELECT COUNT(*) FROM {migration_table}") - assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" - finally: - if config.pool_instance: - config.close_pool() + with config.provide_session() as driver: + count = driver.select_value(f"SELECT COUNT(*) FROM {migration_table}") + assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" + finally: + if config.pool_instance: + config.close_pool() -async def test_oracledb_async_migration_error_handling(oracle_23ai_service: OracleService) -> None: +async def test_oracledb_async_migration_error_handling(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test OracleDB async migration error handling.""" test_id = "oracledb_async_error" migration_table = f"sqlspec_migrations_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleAsyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - "min": 1, - "max": 5, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" + + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with invalid SQL.""" + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -615,44 +598,43 @@ def down(): """Drop table.""" return ["DROP TABLE invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - count = await driver.select_value(f"SELECT COUNT(*) FROM {migration_table}") - assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + count = await driver.select_value(f"SELECT COUNT(*) FROM {migration_table}") + assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" + finally: + if config.pool_instance: + await config.close_pool() -def test_oracledb_sync_migration_with_transactions(oracle_23ai_service: OracleService) -> None: +def test_oracledb_sync_migration_with_transactions(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test OracleDB sync migrations work properly with transactions.""" test_id = "oracledb_sync_trans" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = 
Path(temp_dir) / "migrations" - - config = OracleSyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" + + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - try: - commands.init(str(migration_dir), package=True) + try: + commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -670,75 +652,74 @@ def down(): """Drop users table.""" return ["DROP TABLE {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - driver.begin() - try: - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", - ("Transaction User", "trans@example.com"), - ) - - result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - driver.commit() - except Exception: - driver.rollback() - raise + with config.provide_session() as driver: + driver.begin() + try: + driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", + ("Transaction User", "trans@example.com"), + ) result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + driver.commit() + except Exception: + driver.rollback() + raise - with config.provide_session() as driver: - driver.begin() - try: - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", - ("Rollback User", "rollback@example.com"), - ) + result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 - raise Exception("Intentional rollback") - except Exception: - driver.rollback() + with config.provide_session() as driver: + driver.begin() + try: + driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", + ("Rollback User", "rollback@example.com"), + ) - result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - config.close_pool() + raise Exception("Intentional rollback") + except Exception: + driver.rollback() + + result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + config.close_pool() -async def test_oracledb_async_migration_with_transactions(oracle_23ai_service: OracleService) -> None: +async def test_oracledb_async_migration_with_transactions(tmp_path: Path, oracle_23ai_service: OracleService) -> None: 
"""Test OracleDB async migrations work properly with transactions.""" test_id = "oracledb_async_trans" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = OracleAsyncConfig( - pool_config={ - "host": oracle_23ai_service.host, - "port": oracle_23ai_service.port, - "service_name": oracle_23ai_service.service_name, - "user": oracle_23ai_service.user, - "password": oracle_23ai_service.password, - "min": 1, - "max": 5, - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" - try: - await commands.init(str(migration_dir), package=True) + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - migration_content = f'''"""Initial schema migration.""" + try: + await commands.init(str(migration_dir), package=True) + + migration_content = f'''"""Initial schema migration.""" def up(): @@ -756,48 +737,50 @@ def down(): """Drop users table.""" return ["DROP TABLE {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) - - await commands.upgrade() + (migration_dir / "0001_create_users.py").write_text(migration_content) - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", - ("Transaction User", "trans@example.com"), - ) + await commands.upgrade() - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - await driver.commit() - except Exception: - await driver.rollback() - raise + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", + ("Transaction User", "trans@example.com"), + ) result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", - ("Rollback User", "rollback@example.com"), - ) + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 - raise Exception("Intentional rollback") - except Exception: - await driver.rollback() + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (:1, :2)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + result = await driver.execute(f"SELECT * FROM {users_table} 
WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def test_oracledb_async_schema_migration_from_old_format(oracle_23ai_service: OracleService) -> None: +async def test_oracledb_async_schema_migration_from_old_format( + tmp_path: Path, oracle_23ai_service: OracleService +) -> None: """Test automatic schema migration from old format (without execution_sequence) to new format. This simulates the scenario where a user has an existing database with the old schema @@ -872,7 +855,7 @@ async def test_oracledb_async_schema_migration_from_old_format(oracle_23ai_servi await config.close_pool() -def test_oracledb_sync_schema_migration_from_old_format(oracle_23ai_service: OracleService) -> None: +def test_oracledb_sync_schema_migration_from_old_format(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test automatic schema migration from old format (without execution_sequence) to new format (sync version). This simulates the scenario where a user has an existing database with the old schema diff --git a/tests/integration/test_adapters/test_psqlpy/test_migrations.py b/tests/integration/test_adapters/test_psqlpy/test_migrations.py index e60722e5..5d0275db 100644 --- a/tests/integration/test_adapters/test_psqlpy/test_migrations.py +++ b/tests/integration/test_adapters/test_psqlpy/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for Psqlpy (PostgreSQL) migration workflow.""" -import tempfile from pathlib import Path import pytest @@ -12,30 +11,29 @@ pytestmark = pytest.mark.xdist_group("postgres") -async def test_psqlpy_migration_full_workflow(postgres_service: PostgresService) -> None: +async def test_psqlpy_migration_full_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test full Psqlpy migration workflow: init -> create -> upgrade -> downgrade.""" test_id = "psqlpy_full_workflow" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsqlpyConfig( - pool_config={ - "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsqlpyConfig( + pool_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - await commands.init(str(migration_dir), package=True) + await commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -55,40 +53,40 @@ def down(): return ["DROP TABLE IF EXISTS {users_table}"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - await 
commands.upgrade() + try: + await commands.upgrade() - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" - ) - assert len(result.data) == 1 + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" + ) + assert len(result.data) == 1 - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") - ) + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") + ) - users_result = await driver.execute(f"SELECT * FROM {users_table}") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = await driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - await commands.downgrade("base") + await commands.downgrade("base") - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -async def test_psqlpy_multiple_migrations_workflow(postgres_service: PostgresService) -> None: +async def test_psqlpy_multiple_migrations_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psqlpy workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" test_id = "psqlpy_multi_workflow" @@ -96,20 +94,19 @@ async def test_psqlpy_multiple_migrations_workflow(postgres_service: PostgresSer users_table = f"users_{test_id}" posts_table = f"posts_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsqlpyConfig( - pool_config={ - "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsqlpyConfig( + pool_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = 
AsyncMigrationCommands(config) - await commands.init(str(migration_dir), package=True) + await commands.init(str(migration_dir), package=True) - migration1_content = f'''"""Create users table.""" + migration1_content = f'''"""Create users table.""" def up(): @@ -128,9 +125,9 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) - migration2_content = f'''"""Create posts table.""" + migration2_content = f'''"""Create posts table.""" def up(): @@ -150,78 +147,77 @@ def down(): """Drop posts table.""" return ["DROP TABLE IF EXISTS {posts_table}"] ''' - (migration_dir / "0002_create_posts.py").write_text(migration2_content) - - try: - await commands.upgrade() - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" - ) - posts_result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{posts_table}' AND c.relkind = 'r'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 1 - - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") - ) - await driver.execute( - f"INSERT INTO {posts_table} (title, content, user_id) VALUES ($1, $2, $3)", - ("Test Post", "This is a test post", 1), - ) - - await commands.downgrade("0001") - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" - ) - posts_result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{posts_table}' AND c.relkind = 'r'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 0 - - await commands.downgrade("base") - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname IN ('{users_table}', '{posts_table}') AND c.relkind = 'r'" - ) - assert len(users_result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() - - -async def test_psqlpy_migration_current_command(postgres_service: PostgresService) -> None: + (migration_dir / "0002_create_posts.py").write_text(migration2_content) + + try: + await commands.upgrade() + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" + ) + posts_result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 
'public' AND c.relname = '{posts_table}' AND c.relkind = 'r'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", ("John Doe", "john@example.com") + ) + await driver.execute( + f"INSERT INTO {posts_table} (title, content, user_id) VALUES ($1, $2, $3)", + ("Test Post", "This is a test post", 1), + ) + + await commands.downgrade("0001") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{users_table}' AND c.relkind = 'r'" + ) + posts_result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = '{posts_table}' AND c.relkind = 'r'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 + + await commands.downgrade("base") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT c.relname::text AS table_name FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname IN ('{users_table}', '{posts_table}') AND c.relkind = 'r'" + ) + assert len(users_result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_psqlpy_migration_current_command(tmp_path: Path, postgres_service: PostgresService) -> None: """Test the current migration command shows correct version for Psqlpy.""" test_id = "psqlpy_current_cmd" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsqlpyConfig( - pool_config={ - "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsqlpyConfig( + pool_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - current_version = await commands.current() - assert current_version is None or current_version == "base" + current_version = await commands.current() + assert current_version is None or current_version == "base" - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -238,39 +234,38 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - current_version = await commands.current() - assert current_version == "0001" + current_version = await 
commands.current() + assert current_version == "0001" - await commands.downgrade("base") + await commands.downgrade("base") - current_version = await commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - await config.close_pool() + current_version = await commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + await config.close_pool() -async def test_psqlpy_migration_error_handling(postgres_service: PostgresService) -> None: +async def test_psqlpy_migration_error_handling(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psqlpy migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = PsqlpyConfig( - pool_config={ - "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_psqlpy"}, - ) - commands = AsyncMigrationCommands(config) + migration_dir = tmp_path / "migrations" - try: - await commands.init(str(migration_dir), package=True) + config = PsqlpyConfig( + pool_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_psqlpy"}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with invalid SQL.""" + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -282,44 +277,43 @@ def down(): """Drop table.""" return ["DROP TABLE IF EXISTS invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - try: - await driver.execute("SELECT version FROM sqlspec_migrations_psqlpy ORDER BY version") - msg = "Expected migration table to not exist, but it does" - raise AssertionError(msg) - except Exception as e: - assert "no such" in str(e).lower() or "does not exist" in str(e).lower() - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + try: + await driver.execute("SELECT version FROM sqlspec_migrations_psqlpy ORDER BY version") + msg = "Expected migration table to not exist, but it does" + raise AssertionError(msg) + except Exception as e: + assert "no such" in str(e).lower() or "does not exist" in str(e).lower() + finally: + if config.pool_instance: + await config.close_pool() -async def test_psqlpy_migration_with_transactions(postgres_service: PostgresService) -> None: +async def test_psqlpy_migration_with_transactions(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psqlpy migrations work properly with transactions.""" test_id = "psqlpy_transactions" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsqlpyConfig( - pool_config={ - "dsn": 
f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsqlpyConfig( + pool_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -337,42 +331,42 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", - ("Transaction User", "trans@example.com"), - ) - - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - await driver.commit() - except Exception: - await driver.rollback() - raise + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", + ("Transaction User", "trans@example.com"), + ) result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 + + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES ($1, $2)", - ("Rollback User", "rollback@example.com"), - ) - - raise Exception("Intentional rollback") - except Exception: - await driver.rollback() - - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_psycopg/test_migrations.py b/tests/integration/test_adapters/test_psycopg/test_migrations.py index 680d0a1c..00d6b01d 100644 --- a/tests/integration/test_adapters/test_psycopg/test_migrations.py +++ b/tests/integration/test_adapters/test_psycopg/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for Psycopg (PostgreSQL) migration 
workflow.""" -import tempfile from pathlib import Path from typing import Any @@ -14,30 +13,29 @@ pytestmark = pytest.mark.xdist_group("postgres") -def test_psycopg_sync_migration_full_workflow(postgres_service: PostgresService) -> None: +def test_psycopg_sync_migration_full_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test full Psycopg sync migration workflow: init -> create -> upgrade -> downgrade.""" test_id = "psycopg_sync_full" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsycopgSyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = PsycopgSyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -57,63 +55,60 @@ def down(): return ["DROP TABLE IF EXISTS {users_table}"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - commands.upgrade() + try: + commands.upgrade() - with config.provide_session() as driver: - result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + assert len(result.data) == 1 - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com") - ) + driver.execute(f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com")) - users_result = driver.execute(f"SELECT * FROM {users_table}") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE 
table_schema = 'public' AND table_name = '{users_table}'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - config.close_pool() + with config.provide_session() as driver: + result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + config.close_pool() -async def test_psycopg_async_migration_full_workflow(postgres_service: PostgresService) -> None: +async def test_psycopg_async_migration_full_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test full Psycopg async migration workflow: init -> create -> upgrade -> downgrade.""" test_id = "psycopg_async_full" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsycopgAsyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsycopgAsyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - await commands.init(str(migration_dir), package=True) + await commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -133,40 +128,40 @@ def down(): return ["DROP TABLE IF EXISTS {users_table}"] ''' - migration_file = migration_dir / "0001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) - try: - await commands.upgrade() + try: + await commands.upgrade() - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - assert len(result.data) == 1 + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + assert len(result.data) == 1 - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com") - ) + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com") + ) - users_result = await driver.execute(f"SELECT * FROM {users_table}") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = await driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == 
"John Doe" + assert users_result.data[0]["email"] == "john@example.com" - await commands.downgrade("base") + await commands.downgrade("base") - async with config.provide_session() as driver: - result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() -def test_psycopg_sync_multiple_migrations_workflow(postgres_service: PostgresService) -> None: +def test_psycopg_sync_multiple_migrations_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg sync workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" test_id = "psycopg_sync_multi" @@ -174,20 +169,19 @@ def test_psycopg_sync_multiple_migrations_workflow(postgres_service: PostgresSer users_table = f"users_{test_id}" posts_table = f"posts_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsycopgSyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = PsycopgSyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration1_content = f'''"""Create users table.""" + migration1_content = f'''"""Create users table.""" def up(): @@ -206,9 +200,9 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) - migration2_content = f'''"""Create posts table.""" + migration2_content = f'''"""Create posts table.""" def up(): @@ -228,46 +222,46 @@ def down(): """Drop posts table.""" return ["DROP TABLE IF EXISTS {posts_table}"] ''' - (migration_dir / "0002_create_posts.py").write_text(migration2_content) - - try: - commands.upgrade() - - with config.provide_session() as driver: - users_result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - posts_result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 1 - - commands.downgrade("0001") - - with config.provide_session() as driver: - users_result 
= driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - posts_result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 0 - - commands.downgrade("base") - - with config.provide_session() as driver: - users_result = driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('{users_table}', '{posts_table}')" - ) - assert len(users_result.data) == 0 - finally: - if config.pool_instance: - config.close_pool() - - -async def test_psycopg_async_multiple_migrations_workflow(postgres_service: PostgresService) -> None: + (migration_dir / "0002_create_posts.py").write_text(migration2_content) + + try: + commands.upgrade() + + with config.provide_session() as driver: + users_result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + posts_result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + commands.downgrade("0001") + + with config.provide_session() as driver: + users_result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + posts_result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 + + commands.downgrade("base") + + with config.provide_session() as driver: + users_result = driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('{users_table}', '{posts_table}')" + ) + assert len(users_result.data) == 0 + finally: + if config.pool_instance: + config.close_pool() + + +async def test_psycopg_async_multiple_migrations_workflow(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg async workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" test_id = "psycopg_async_multi" @@ -275,25 +269,24 @@ async def test_psycopg_async_multiple_migrations_workflow(postgres_service: Post users_table = f"users_{test_id}" posts_table = f"posts_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - try: - from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig - except ImportError: - pytest.skip("PsycopgAsyncConfig not available") + try: + from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig + except ImportError: + pytest.skip("PsycopgAsyncConfig not available") - config = PsycopgAsyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsycopgAsyncConfig( + pool_config={ + "conninfo": 
f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - await commands.init(str(migration_dir), package=True) + await commands.init(str(migration_dir), package=True) - migration1_content = f'''"""Create users table.""" + migration1_content = f'''"""Create users table.""" def up(): @@ -312,9 +305,9 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) - migration2_content = f'''"""Create posts table.""" + migration2_content = f'''"""Create posts table.""" def up(): @@ -334,70 +327,69 @@ def down(): """Drop posts table.""" return ["DROP TABLE IF EXISTS {posts_table}"] ''' - (migration_dir / "0002_create_posts.py").write_text(migration2_content) - - try: - await commands.upgrade() - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - posts_result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 1 - - await commands.downgrade("0001") - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" - ) - posts_result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" - ) - assert len(users_result.data) == 1 - assert len(posts_result.data) == 0 - - await commands.downgrade("base") - - async with config.provide_session() as driver: - users_result = await driver.execute( - f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('{users_table}', '{posts_table}')" - ) - assert len(users_result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() - - -def test_psycopg_sync_migration_current_command(postgres_service: PostgresService) -> None: + (migration_dir / "0002_create_posts.py").write_text(migration2_content) + + try: + await commands.upgrade() + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + posts_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + await commands.downgrade("0001") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{users_table}'" + ) + posts_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{posts_table}'" + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 + + await 
commands.downgrade("base") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('{users_table}', '{posts_table}')" + ) + assert len(users_result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() + + +def test_psycopg_sync_migration_current_command(tmp_path: Path, postgres_service: PostgresService) -> None: """Test the current migration command shows correct version for Psycopg sync.""" test_id = "psycopg_sync_current" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsycopgSyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = PsycopgSyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - try: - commands.init(str(migration_dir), package=True) + try: + commands.init(str(migration_dir), package=True) - current_version = commands.current() - assert current_version is None or current_version == "base" + current_version = commands.current() + assert current_version is None or current_version == "base" - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -414,52 +406,51 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - current_version = commands.current() - assert current_version == "0001" + current_version = commands.current() + assert current_version == "0001" - commands.downgrade("base") + commands.downgrade("base") - current_version = commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - config.close_pool() + current_version = commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + config.close_pool() -async def test_psycopg_async_migration_current_command(postgres_service: PostgresService) -> None: +async def test_psycopg_async_migration_current_command(tmp_path: Path, postgres_service: PostgresService) -> None: """Test the current migration command shows correct version for Psycopg async.""" test_id = "psycopg_async_current" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - try: - from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig - except 
ImportError: - pytest.skip("PsycopgAsyncConfig not available") + try: + from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig + except ImportError: + pytest.skip("PsycopgAsyncConfig not available") - config = PsycopgAsyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsycopgAsyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - current_version = await commands.current() - assert current_version is None or current_version == "base" + current_version = await commands.current() + assert current_version is None or current_version == "base" - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -476,42 +467,41 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) + (migration_dir / "0001_create_users.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - current_version = await commands.current() - assert current_version == "0001" + current_version = await commands.current() + assert current_version == "0001" - await commands.downgrade("base") + await commands.downgrade("base") - current_version = await commands.current() - assert current_version is None or current_version == "base" - finally: - if config.pool_instance: - await config.close_pool() + current_version = await commands.current() + assert current_version is None or current_version == "base" + finally: + if config.pool_instance: + await config.close_pool() -def test_psycopg_sync_migration_error_handling(postgres_service: PostgresService) -> None: +def test_psycopg_sync_migration_error_handling(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg sync migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - config = PsycopgSyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_psycopg_sync_error", - }, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - - try: - commands.init(str(migration_dir), package=True) + migration_dir = tmp_path / "migrations" + + config = PsycopgSyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_psycopg_sync_error", + }, + ) + commands: 
SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + + try: + commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with invalid SQL.""" + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -523,47 +513,46 @@ def down(): """Drop table.""" return ["DROP TABLE IF EXISTS invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - try: - driver.execute("SELECT version FROM sqlspec_migrations_psycopg_sync_error ORDER BY version") - msg = "Expected migration table to not exist, but it does" - raise AssertionError(msg) - except Exception as e: - assert "no such" in str(e).lower() or "does not exist" in str(e).lower() - finally: - if config.pool_instance: - config.close_pool() + with config.provide_session() as driver: + try: + driver.execute("SELECT version FROM sqlspec_migrations_psycopg_sync_error ORDER BY version") + msg = "Expected migration table to not exist, but it does" + raise AssertionError(msg) + except Exception as e: + assert "no such" in str(e).lower() or "does not exist" in str(e).lower() + finally: + if config.pool_instance: + config.close_pool() -async def test_psycopg_async_migration_error_handling(postgres_service: PostgresService) -> None: +async def test_psycopg_async_migration_error_handling(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg async migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - try: - from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig - except ImportError: - pytest.skip("PsycopgAsyncConfig not available") - - config = PsycopgAsyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_psycopg_async_error", - }, - ) - commands = AsyncMigrationCommands(config) - - try: - await commands.init(str(migration_dir), package=True) + migration_dir = tmp_path / "migrations" + + try: + from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig + except ImportError: + pytest.skip("PsycopgAsyncConfig not available") + + config = PsycopgAsyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_psycopg_async_error", + }, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with invalid SQL.""" + migration_content = '''"""Migration with invalid SQL.""" def up(): @@ -575,44 +564,43 @@ def down(): """Drop table.""" return ["DROP TABLE IF EXISTS invalid_table"] ''' - (migration_dir / "0001_invalid.py").write_text(migration_content) + (migration_dir / "0001_invalid.py").write_text(migration_content) - await commands.upgrade() + await commands.upgrade() - async with config.provide_session() as driver: - try: - await driver.execute("SELECT version FROM sqlspec_migrations_psycopg_async_error ORDER BY version") - msg = 
"Expected migration table to not exist, but it does" - raise AssertionError(msg) - except Exception as e: - assert "no such" in str(e).lower() or "does not exist" in str(e).lower() - finally: - if config.pool_instance: - await config.close_pool() + async with config.provide_session() as driver: + try: + await driver.execute("SELECT version FROM sqlspec_migrations_psycopg_async_error ORDER BY version") + msg = "Expected migration table to not exist, but it does" + raise AssertionError(msg) + except Exception as e: + assert "no such" in str(e).lower() or "does not exist" in str(e).lower() + finally: + if config.pool_instance: + await config.close_pool() -def test_psycopg_sync_migration_with_transactions(postgres_service: PostgresService) -> None: +def test_psycopg_sync_migration_with_transactions(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg sync migrations work properly with transactions.""" test_id = "psycopg_sync_trans" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = PsycopgSyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + config = PsycopgSyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - try: - commands.init(str(migration_dir), package=True) + try: + commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -630,74 +618,73 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) - - commands.upgrade() + (migration_dir / "0001_create_users.py").write_text(migration_content) - with config.provide_session() as driver: - driver.begin() - try: - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Transaction User", "trans@example.com"), - ) + commands.upgrade() - result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - driver.commit() - except Exception: - driver.rollback() - raise + with config.provide_session() as driver: + driver.begin() + try: + driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Transaction User", "trans@example.com"), + ) result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + driver.commit() + except Exception: + driver.rollback() + raise - with config.provide_session() as driver: - driver.begin() - try: - driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Rollback User", "rollback@example.com"), - ) + result = driver.execute(f"SELECT * 
FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 - raise Exception("Intentional rollback") - except Exception: - driver.rollback() + with config.provide_session() as driver: + driver.begin() + try: + driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + driver.rollback() - result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - config.close_pool() + result = driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + config.close_pool() -async def test_psycopg_async_migration_with_transactions(postgres_service: PostgresService) -> None: +async def test_psycopg_async_migration_with_transactions(tmp_path: Path, postgres_service: PostgresService) -> None: """Test Psycopg async migrations work properly with transactions.""" test_id = "psycopg_async_trans" migration_table = f"sqlspec_migrations_{test_id}" users_table = f"users_{test_id}" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - try: - from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig - except ImportError: - pytest.skip("PsycopgAsyncConfig not available") + try: + from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig + except ImportError: + pytest.skip("PsycopgAsyncConfig not available") - config = PsycopgAsyncConfig( - pool_config={ - "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" - }, - migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, - ) - commands = AsyncMigrationCommands(config) + config = PsycopgAsyncConfig( + pool_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) - try: - await commands.init(str(migration_dir), package=True) + try: + await commands.init(str(migration_dir), package=True) - migration_content = f'''"""Initial schema migration.""" + migration_content = f'''"""Initial schema migration.""" def up(): @@ -715,42 +702,42 @@ def down(): """Drop users table.""" return ["DROP TABLE IF EXISTS {users_table}"] ''' - (migration_dir / "0001_create_users.py").write_text(migration_content) - - await commands.upgrade() + (migration_dir / "0001_create_users.py").write_text(migration_content) - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Transaction User", "trans@example.com"), - ) + await commands.upgrade() - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") - assert len(result.data) == 1 - await driver.commit() - except Exception: - await driver.rollback() - raise + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Transaction User", "trans@example.com"), + ) result = 
await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 + + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() - async with config.provide_session() as driver: - await driver.begin() - try: - await driver.execute( - f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", - ("Rollback User", "rollback@example.com"), - ) - - raise Exception("Intentional rollback") - except Exception: - await driver.rollback() - - result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") - assert len(result.data) == 0 - finally: - if config.pool_instance: - await config.close_pool() + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_sqlite/test_migrations.py b/tests/integration/test_adapters/test_sqlite/test_migrations.py index ded4b14a..91046aaa 100644 --- a/tests/integration/test_adapters/test_sqlite/test_migrations.py +++ b/tests/integration/test_adapters/test_sqlite/test_migrations.py @@ -1,6 +1,5 @@ """Integration tests for SQLite migration workflow.""" -import tempfile from pathlib import Path from typing import Any @@ -12,24 +11,22 @@ pytestmark = pytest.mark.xdist_group("sqlite") -def test_sqlite_migration_full_workflow() -> None: +def test_sqlite_migration_full_workflow(tmp_path: Path) -> None: """Test full SQLite migration workflow: init -> create -> upgrade -> downgrade.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - temp_db = str(Path(temp_dir) / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() - migration_content = '''"""Initial schema migration.""" + migration_content = '''"""Initial schema migration.""" def up(): @@ -49,44 +46,42 @@ def down(): return ["DROP TABLE IF EXISTS users"] ''' - migration_file = migration_dir / "001_create_users.py" - migration_file.write_text(migration_content) + migration_file = migration_dir / "001_create_users.py" + migration_file.write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as 
driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") + assert len(result.data) == 1 - driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("John Doe", "john@example.com")) + driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("John Doe", "john@example.com")) - users_result = driver.execute("SELECT * FROM users") - assert len(users_result.data) == 1 - assert users_result.data[0]["name"] == "John Doe" - assert users_result.data[0]["email"] == "john@example.com" + users_result = driver.execute("SELECT * FROM users") + assert len(users_result.data) == 1 + assert users_result.data[0]["name"] == "John Doe" + assert users_result.data[0]["email"] == "john@example.com" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='users'") + assert len(result.data) == 0 -def test_sqlite_multiple_migrations_workflow() -> None: +def test_sqlite_multiple_migrations_workflow(tmp_path: Path) -> None: """Test SQLite workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - temp_db = str(Path(temp_dir) / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration1_content = '''"""Create users table.""" + migration1_content = '''"""Create users table.""" def up(): @@ -105,7 +100,7 @@ def down(): return ["DROP TABLE IF EXISTS users"] ''' - migration2_content = '''"""Create posts table.""" + migration2_content = '''"""Create posts table.""" def up(): @@ -126,62 +121,56 @@ def down(): return ["DROP TABLE IF EXISTS posts"] ''' - (migration_dir / "0001_create_users.py").write_text(migration1_content) - (migration_dir / "0002_create_posts.py").write_text(migration2_content) + (migration_dir / "0001_create_users.py").write_text(migration1_content) + (migration_dir / "0002_create_posts.py").write_text(migration2_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") - table_names = [t["name"] for t in tables_result.data] - assert "users" in table_names - assert "posts" in table_names + with config.provide_session() as driver: + tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") + 
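The sqlite_master existence checks repeated throughout these workflows can be read as one small helper. A minimal stdlib-only sketch (the table_exists name is ours, not part of sqlspec):

import sqlite3

def table_exists(db_path: str, name: str) -> bool:
    # sqlite_master lists every table in the database file; filtering by
    # name gives the same existence check the assertions in these tests perform.
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (name,)
        ).fetchone()
    return row is not None
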
table_names = [t["name"] for t in tables_result.data] + assert "users" in table_names + assert "posts" in table_names - driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("Author", "author@example.com")) - driver.execute( - "INSERT INTO posts (title, content, user_id) VALUES (?, ?, ?)", ("My Post", "Post content", 1) - ) + driver.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("Author", "author@example.com")) + driver.execute("INSERT INTO posts (title, content, user_id) VALUES (?, ?, ?)", ("My Post", "Post content", 1)) - posts_result = driver.execute("SELECT * FROM posts") - assert len(posts_result.data) == 1 - assert posts_result.data[0]["title"] == "My Post" + posts_result = driver.execute("SELECT * FROM posts") + assert len(posts_result.data) == 1 + assert posts_result.data[0]["title"] == "My Post" - commands.downgrade("0001") + commands.downgrade("0001") - with config.provide_session() as driver: - tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table'") - table_names = [t["name"] for t in tables_result.data] - assert "users" in table_names - assert "posts" not in table_names + with config.provide_session() as driver: + tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table'") + table_names = [t["name"] for t in tables_result.data] + assert "users" in table_names + assert "posts" not in table_names - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - tables_result = driver.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'" - ) + with config.provide_session() as driver: + tables_result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") - table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")] - assert len(table_names) == 0 + table_names = [t["name"] for t in tables_result.data if not t["name"].startswith("sqlspec_")] + assert len(table_names) == 0 -def test_sqlite_migration_current_command() -> None: +def test_sqlite_migration_current_command(tmp_path: Path) -> None: """Test the current migration command shows correct version for SQLite.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - temp_db = str(Path(temp_dir) / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - commands.current(verbose=False) + commands.current(verbose=False) - migration_content = '''"""Test migration.""" + migration_content = '''"""Test migration.""" def up(): @@ -194,28 +183,26 @@ def down(): return ["DROP TABLE IF EXISTS test_table"] ''' - (migration_dir / "001_test.py").write_text(migration_content) + (migration_dir / "001_test.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - 
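The pair of downgrade assertions above pins down the targeting contract: downgrade("0001") reverts only what was applied after 0001, while downgrade("base") reverts everything. A tiny illustration of that contract, assuming plain lexicographic version ordering — this models the observable behaviour only, not sqlspec's implementation:

def versions_to_revert(applied: list[str], target: str) -> list[str]:
    # "base" reverts every applied version; any other target reverts only
    # the versions strictly greater than it, newest first.
    pending = applied[:] if target == "base" else [v for v in applied if v > target]
    return sorted(pending, reverse=True)

assert versions_to_revert(["0001", "0002"], "0001") == ["0002"]
assert versions_to_revert(["0001", "0002"], "base") == ["0002", "0001"]
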
commands.current(verbose=True) + commands.current(verbose=True) -def test_sqlite_migration_error_handling() -> None: +def test_sqlite_migration_error_handling(tmp_path: Path) -> None: """Test SQLite migration error handling.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - - temp_db = str(Path(temp_dir) / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration_content = '''"""Bad migration.""" + migration_content = '''"""Bad migration.""" def up(): @@ -228,34 +215,32 @@ def down(): return [] ''' - (migration_dir / "001_bad.py").write_text(migration_content) + (migration_dir / "001_bad.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with config.provide_session() as driver: - try: - driver.execute("SELECT version FROM sqlspec_migrations ORDER BY version") - msg = "Expected migration table to not exist, but it does" - raise AssertionError(msg) - except Exception as e: - assert "no such" in str(e).lower() or "does not exist" in str(e).lower() + with config.provide_session() as driver: + try: + driver.execute("SELECT version FROM sqlspec_migrations ORDER BY version") + msg = "Expected migration table to not exist, but it does" + raise AssertionError(msg) + except Exception as e: + assert "no such" in str(e).lower() or "does not exist" in str(e).lower() -def test_sqlite_migration_with_transactions() -> None: +def test_sqlite_migration_with_transactions(tmp_path: Path) -> None: """Test SQLite migrations work properly with transactions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - temp_db = str(Path(temp_dir) / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) - commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) - - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - migration_content = '''"""Migration with multiple operations.""" + migration_content = '''"""Migration with multiple operations.""" def up(): @@ -275,37 +260,36 @@ def down(): return ["DROP TABLE IF EXISTS customers"] ''' - (migration_dir / "0001_transaction_test.py").write_text(migration_content) + (migration_dir / "0001_transaction_test.py").write_text(migration_content) - commands.upgrade() + commands.upgrade() - with 
config.provide_session() as driver: - customers_result = driver.execute("SELECT * FROM customers ORDER BY name") - assert len(customers_result.data) == 2 - assert customers_result.data[0]["name"] == "Customer 1" - assert customers_result.data[1]["name"] == "Customer 2" + with config.provide_session() as driver: + customers_result = driver.execute("SELECT * FROM customers ORDER BY name") + assert len(customers_result.data) == 2 + assert customers_result.data[0]["name"] == "Customer 1" + assert customers_result.data[1]["name"] == "Customer 2" - commands.downgrade("base") + commands.downgrade("base") - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='customers'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='customers'") + assert len(result.data) == 0 -def test_sqlite_config_migrate_up_method() -> None: +def test_sqlite_config_migrate_up_method(tmp_path: Path) -> None: """Test SqliteConfig.migrate_up() method works correctly.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - migration_content = '''"""Create products table.""" + migration_content = '''"""Create products table.""" def up(): @@ -324,29 +308,28 @@ def down(): return ["DROP TABLE IF EXISTS products"] ''' - (migration_dir / "0001_create_products.py").write_text(migration_content) + (migration_dir / "0001_create_products.py").write_text(migration_content) - config.migrate_up() + config.migrate_up() - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='products'") - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='products'") + assert len(result.data) == 1 -def test_sqlite_config_migrate_down_method() -> None: +def test_sqlite_config_migrate_down_method(tmp_path: Path) -> None: """Test SqliteConfig.migrate_down() method works correctly.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - migration_content = '''"""Create inventory table.""" + migration_content = '''"""Create inventory table.""" def up(): @@ -364,38 +347,37 @@ def down(): return ["DROP TABLE IF EXISTS inventory"] ''' - (migration_dir / 
"0001_create_inventory.py").write_text(migration_content) + (migration_dir / "0001_create_inventory.py").write_text(migration_content) - config.migrate_up() + config.migrate_up() - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") - assert len(result.data) == 1 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") + assert len(result.data) == 1 - config.migrate_down() + config.migrate_down() - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") + assert len(result.data) == 0 -def test_sqlite_config_get_current_migration_method() -> None: +def test_sqlite_config_get_current_migration_method(tmp_path: Path) -> None: """Test SqliteConfig.get_current_migration() method returns correct version.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - current_version = config.get_current_migration() - assert current_version is None + current_version = config.get_current_migration() + assert current_version is None - migration_content = '''"""First migration.""" + migration_content = '''"""First migration.""" def up(): @@ -408,50 +390,48 @@ def down(): return ["DROP TABLE IF EXISTS test_version"] ''' - (migration_dir / "0001_first.py").write_text(migration_content) + (migration_dir / "0001_first.py").write_text(migration_content) - config.migrate_up() + config.migrate_up() - current_version = config.get_current_migration() - assert current_version == "0001" + current_version = config.get_current_migration() + assert current_version == "0001" -def test_sqlite_config_create_migration_method() -> None: +def test_sqlite_config_create_migration_method(tmp_path: Path) -> None: """Test SqliteConfig.create_migration() method generates migration file.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - config.create_migration("add users table", file_type="py") + config.create_migration("add users table", file_type="py") - migration_files = list(migration_dir.glob("*.py")) - migration_files = [f for f in migration_files if f.name != "__init__.py"] + 
migration_files = list(migration_dir.glob("*.py")) + migration_files = [f for f in migration_files if f.name != "__init__.py"] - assert len(migration_files) == 1 - assert "add_users_table" in migration_files[0].name + assert len(migration_files) == 1 + assert "add_users_table" in migration_files[0].name -def test_sqlite_config_stamp_migration_method() -> None: +def test_sqlite_config_stamp_migration_method(tmp_path: Path) -> None: """Test SqliteConfig.stamp_migration() method marks database at revision.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - migration_content = '''"""Stamped migration.""" + migration_content = '''"""Stamped migration.""" def up(): @@ -464,32 +444,31 @@ def down(): return ["DROP TABLE IF EXISTS stamped"] ''' - (migration_dir / "0001_stamped.py").write_text(migration_content) + (migration_dir / "0001_stamped.py").write_text(migration_content) - config.stamp_migration("0001") + config.stamp_migration("0001") - current_version = config.get_current_migration() - assert current_version == "0001" + current_version = config.get_current_migration() + assert current_version == "0001" - with config.provide_session() as driver: - result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='stamped'") - assert len(result.data) == 0 + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='stamped'") + assert len(result.data) == 0 -def test_sqlite_config_fix_migrations_dry_run() -> None: +def test_sqlite_config_fix_migrations_dry_run(tmp_path: Path) -> None: """Test SqliteConfig.fix_migrations() dry run shows what would change.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, - ) + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) - config.init_migrations() + config.init_migrations() - timestamp_migration = '''"""Timestamp migration.""" + timestamp_migration = '''"""Timestamp migration.""" def up(): @@ -502,12 +481,12 @@ def down(): return ["DROP TABLE IF EXISTS timestamp_test"] ''' - (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) + (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) - config.fix_migrations(dry_run=True, yes=True) + config.fix_migrations(dry_run=True, yes=True) - timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" - assert timestamp_file.exists() + timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" + assert timestamp_file.exists() - 
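fix_migrations(dry_run=True) is expected, per the assertions around this point, to leave the timestamp-named file in place and create no sequential twin. A hypothetical planner that computes the renames without touching the filesystem shows what such a dry run could report; plan_renames and its regex are assumptions for illustration, not sqlspec API:

import re
from pathlib import Path

def plan_renames(migration_dir: Path) -> dict[Path, Path]:
    # Map timestamp-prefixed migrations (14-digit prefix) to sequential
    # names, in sorted order, without renaming anything on disk.
    timestamp = re.compile(r"^\d{14}_(?P<slug>.+)\.py$")
    renames: dict[Path, Path] = {}
    counter = 1
    for path in sorted(migration_dir.glob("*.py")):
        match = timestamp.match(path.name)
        if match:
            renames[path] = path.with_name(f"{counter:04d}_{match.group('slug')}.py")
            counter += 1
    return renames
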
sequential_file = migration_dir / "0001_timestamp_migration.py" - assert not sequential_file.exists() + sequential_file = migration_dir / "0001_timestamp_migration.py" + assert not sequential_file.exists() diff --git a/tests/integration/test_async_migrations.py b/tests/integration/test_async_migrations.py index 6c09a804..196754f6 100644 --- a/tests/integration/test_async_migrations.py +++ b/tests/integration/test_async_migrations.py @@ -1,9 +1,7 @@ """Integration tests for async migrations functionality.""" import asyncio -import tempfile from pathlib import Path -from typing import Any from unittest.mock import Mock import pytest @@ -16,123 +14,126 @@ pytestmark = pytest.mark.xdist_group("migrations") -class TestAsyncMigrationsIntegration: - """Integration tests for async migrations functionality.""" - - @pytest.fixture - def temp_migration_dir(self) -> Any: - """Create a temporary migration directory.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - migration_dir.mkdir() - yield migration_dir - - @pytest.fixture - def mock_config(self) -> Any: - """Create a mock configuration.""" - config = Mock() - config.database_url = "sqlite:///test.db" - config.bind_key = "test" - config.migration_config = {"script_location": "migrations", "version_table_name": "alembic_version"} - return config - - def test_async_migration_context_properties(self) -> None: - """Test async migration context properties.""" - context = MigrationContext(dialect="postgres") - - # Test execution mode detection - assert context.execution_mode == "sync" - - # Test metadata operations - context.set_execution_metadata("test_key", "test_value") - assert context.get_execution_metadata("test_key") == "test_value" - - def test_sync_callable_config_resolution(self) -> None: - """Test resolving synchronous callable config.""" - mock_config = Mock() - mock_config.database_url = "sqlite:///test.db" - mock_config.bind_key = "test" - mock_config.migration_config = {} - - # Create a config factory function - def get_test_config() -> Mock: - return mock_config - - async def _test() -> None: - # Mock the import_string to return our function - import sqlspec.utils.config_resolver - - original_import = sqlspec.utils.config_resolver.import_string - - try: - sqlspec.utils.config_resolver.import_string = lambda path: get_test_config - result = await resolve_config_async("test.config.get_database_config") - assert result is mock_config - finally: - sqlspec.utils.config_resolver.import_string = original_import - - run_(_test)() - - def test_async_callable_config_resolution(self) -> None: - """Test resolving asynchronous callable config.""" - mock_config = Mock() - mock_config.database_url = "sqlite:///test.db" - mock_config.bind_key = "test" - mock_config.migration_config = {} - - # Create an async config factory function - async def get_test_config() -> Mock: - return mock_config - - async def _test() -> None: - # Mock the import_string to return our async function - import sqlspec.utils.config_resolver - - original_import = sqlspec.utils.config_resolver.import_string - - try: - sqlspec.utils.config_resolver.import_string = lambda path: get_test_config - result = await resolve_config_async("test.config.async_get_database_config") - assert result is mock_config - finally: - sqlspec.utils.config_resolver.import_string = original_import - - run_(_test)() - - def test_sync_migration_runner_instantiation(self, temp_migration_dir: Any, mock_config: Any) -> None: - """Test sync migration runner 
instantiation.""" - - context = MigrationContext.from_config(mock_config) - runner = SyncMigrationRunner(temp_migration_dir, {}, context, {}) - - # Verify it's a sync runner - assert isinstance(runner, SyncMigrationRunner) - assert hasattr(runner, "load_migration") - assert hasattr(runner, "execute_upgrade") - - def test_async_migration_runner_instantiation(self, temp_migration_dir: Any, mock_config: Any) -> None: - """Test async migration runner instantiation.""" - from sqlspec.migrations.runner import AsyncMigrationRunner - - context = MigrationContext.from_config(mock_config) - runner = AsyncMigrationRunner(temp_migration_dir, {}, context, {}) - - # Verify it's an async runner - assert isinstance(runner, AsyncMigrationRunner) - assert hasattr(runner, "load_migration") - assert hasattr(runner, "execute_upgrade") - - # Verify methods are async - import inspect - - assert inspect.iscoroutinefunction(runner.load_migration) - assert inspect.iscoroutinefunction(runner.execute_upgrade) - - def test_async_python_migration_execution(self, temp_migration_dir: Any) -> None: - """Test execution of async Python migration.""" - # Create async Python migration file - migration_file = temp_migration_dir / "0001_create_users_async.py" - migration_content = '''"""Create users table with async validation.""" +def test_async_migration_context_properties() -> None: + """Test async migration context properties.""" + context = MigrationContext(dialect="postgres") + + # Test execution mode detection + assert context.execution_mode == "sync" + + # Test metadata operations + context.set_execution_metadata("test_key", "test_value") + assert context.get_execution_metadata("test_key") == "test_value" + + +def test_sync_callable_config_resolution() -> None: + """Test resolving synchronous callable config.""" + mock_config = Mock() + mock_config.database_url = "sqlite:///test.db" + mock_config.bind_key = "test" + mock_config.migration_config = {} + + # Create a config factory function + def get_test_config() -> Mock: + return mock_config + + async def _test() -> None: + # Mock the import_string to return our function + import sqlspec.utils.config_resolver + + original_import = sqlspec.utils.config_resolver.import_string + + try: + sqlspec.utils.config_resolver.import_string = lambda path: get_test_config + result = await resolve_config_async("test.config.get_database_config") + assert result is mock_config + finally: + sqlspec.utils.config_resolver.import_string = original_import + + run_(_test)() + + +def test_async_callable_config_resolution() -> None: + """Test resolving asynchronous callable config.""" + mock_config = Mock() + mock_config.database_url = "sqlite:///test.db" + mock_config.bind_key = "test" + mock_config.migration_config = {} + + # Create an async config factory function + async def get_test_config() -> Mock: + return mock_config + + async def _test() -> None: + # Mock the import_string to return our async function + import sqlspec.utils.config_resolver + + original_import = sqlspec.utils.config_resolver.import_string + + try: + sqlspec.utils.config_resolver.import_string = lambda path: get_test_config + result = await resolve_config_async("test.config.async_get_database_config") + assert result is mock_config + finally: + sqlspec.utils.config_resolver.import_string = original_import + + run_(_test)() + + +def test_sync_migration_runner_instantiation(tmp_path: Path) -> None: + """Test sync migration runner instantiation.""" + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + + 
mock_config = Mock() + mock_config.database_url = "sqlite:///test.db" + mock_config.bind_key = "test" + mock_config.migration_config = {"script_location": "migrations", "version_table_name": "alembic_version"} + + context = MigrationContext.from_config(mock_config) + runner = SyncMigrationRunner(migration_dir, {}, context, {}) + + # Verify it's a sync runner + assert isinstance(runner, SyncMigrationRunner) + assert hasattr(runner, "load_migration") + assert hasattr(runner, "execute_upgrade") + + +def test_async_migration_runner_instantiation(tmp_path: Path) -> None: + """Test async migration runner instantiation.""" + from sqlspec.migrations.runner import AsyncMigrationRunner + + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + + mock_config = Mock() + mock_config.database_url = "sqlite:///test.db" + mock_config.bind_key = "test" + mock_config.migration_config = {"script_location": "migrations", "version_table_name": "alembic_version"} + + context = MigrationContext.from_config(mock_config) + runner = AsyncMigrationRunner(migration_dir, {}, context, {}) + + # Verify it's an async runner + assert isinstance(runner, AsyncMigrationRunner) + assert hasattr(runner, "load_migration") + assert hasattr(runner, "execute_upgrade") + + # Verify methods are async + import inspect + + assert inspect.iscoroutinefunction(runner.load_migration) + assert inspect.iscoroutinefunction(runner.execute_upgrade) + + +def test_async_python_migration_execution(tmp_path: Path) -> None: + """Test execution of async Python migration.""" + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + + # Create async Python migration file + migration_file = migration_dir / "0001_create_users_async.py" + migration_content = '''"""Create users table with async validation.""" async def up(context): """Create users table.""" @@ -150,31 +151,35 @@ async def down(context): """Drop users table.""" return ["DROP TABLE users;"] ''' - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) + + # Test loading the migration + from sqlspec.migrations.loaders import PythonFileLoader - # Test loading the migration - from sqlspec.migrations.loaders import PythonFileLoader + context = MigrationContext(dialect="postgres") + loader = PythonFileLoader(migration_dir, tmp_path, context) - context = MigrationContext(dialect="postgres") - loader = PythonFileLoader(temp_migration_dir, temp_migration_dir.parent, context) + # Test async execution + async def test_async_loading() -> None: + up_sql = await loader.get_up_sql(migration_file) + assert len(up_sql) == 1 + assert "CREATE TABLE users" in up_sql[0] - # Test async execution - async def test_async_loading() -> None: - up_sql = await loader.get_up_sql(migration_file) - assert len(up_sql) == 1 - assert "CREATE TABLE users" in up_sql[0] + down_sql = await loader.get_down_sql(migration_file) + assert len(down_sql) == 1 + assert "DROP TABLE users" in down_sql[0] - down_sql = await loader.get_down_sql(migration_file) - assert len(down_sql) == 1 - assert "DROP TABLE users" in down_sql[0] + asyncio.run(test_async_loading()) - asyncio.run(test_async_loading()) - def test_mixed_sync_async_migration_loading(self, temp_migration_dir: Any) -> None: - """Test loading both sync and async migrations in the same directory.""" - # Create sync migration - sync_migration = temp_migration_dir / "0001_sync_migration.py" - sync_migration.write_text(""" +def test_mixed_sync_async_migration_loading(tmp_path: Path) -> None: + """Test loading both sync and 
async migrations in the same directory.""" + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + + # Create sync migration + sync_migration = migration_dir / "0001_sync_migration.py" + sync_migration.write_text(""" def up(context): return ["CREATE TABLE sync_test (id INT);"] @@ -182,9 +187,9 @@ def down(context): return ["DROP TABLE sync_test;"] """) - # Create async migration - async_migration = temp_migration_dir / "0002_async_migration.py" - async_migration.write_text(""" + # Create async migration + async_migration = migration_dir / "0002_async_migration.py" + async_migration.write_text(""" async def up(context): return ["CREATE TABLE async_test (id INT);"] @@ -192,41 +197,51 @@ async def down(context): return ["DROP TABLE async_test;"] """) - context = MigrationContext(dialect="postgres") - runner = create_migration_runner(temp_migration_dir, {}, context, {}, is_async=False) + mock_config = Mock() + mock_config.database_url = "sqlite:///test.db" + mock_config.bind_key = "test" + mock_config.migration_config = {"script_location": "migrations", "version_table_name": "alembic_version"} + + context = MigrationContext(dialect="postgres") + runner = create_migration_runner(migration_dir, {}, context, {}, is_async=False) - # Get migration files - migrations = runner.get_migration_files() - assert len(migrations) == 2 + # Get migration files + migrations = runner.get_migration_files() + assert len(migrations) == 2 - # Verify both migrations are loaded - versions = [version for version, _ in migrations] - assert "0001" in versions - assert "0002" in versions + # Verify both migrations are loaded + versions = [version for version, _ in migrations] + assert "0001" in versions + assert "0002" in versions - def test_migration_context_validation(self) -> None: - """Test migration context async usage validation.""" - context = MigrationContext() - # Test with sync function - def sync_migration() -> list[str]: - return ["CREATE TABLE test (id INT);"] +def test_migration_context_validation() -> None: + """Test migration context async usage validation.""" + context = MigrationContext() - # Should not raise any exceptions - context.validate_async_usage(sync_migration) + # Test with sync function + def sync_migration() -> list[str]: + return ["CREATE TABLE test (id INT);"] - # Test with async function - async def async_migration() -> list[str]: - return ["CREATE TABLE test (id INT);"] + # Should not raise any exceptions + context.validate_async_usage(sync_migration) - # Should handle async function validation - context.validate_async_usage(async_migration) + # Test with async function + async def async_migration() -> list[str]: + return ["CREATE TABLE test (id INT);"] - def test_error_handling_in_async_migrations(self, temp_migration_dir: Any) -> None: - """Test error handling in async migration execution.""" - # Create migration with error - error_migration = temp_migration_dir / "0001_error_migration.py" - error_migration.write_text(""" + # Should handle async function validation + context.validate_async_usage(async_migration) + + +def test_error_handling_in_async_migrations(tmp_path: Path) -> None: + """Test error handling in async migration execution.""" + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + + # Create migration with error + error_migration = migration_dir / "0001_error_migration.py" + error_migration.write_text(""" async def up(context): raise ValueError("Test error in migration") @@ -234,47 +249,48 @@ def down(context): return ["DROP TABLE test;"] """) - 
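The error-handling test below relies on an exception raised inside an async up() surfacing through the loader. The mechanism is ordinary asyncio behaviour — asyncio.run re-raises whatever the coroutine raised — as this self-contained sketch shows (failing_up is a stand-in, not a sqlspec helper):

import asyncio
import pytest

async def failing_up(context: object) -> list[str]:
    raise ValueError("Test error in migration")

def test_async_error_propagates() -> None:
    # asyncio.run propagates the coroutine's exception to the caller,
    # so a plain pytest.raises is enough to assert on it.
    with pytest.raises(ValueError, match="Test error in migration"):
        asyncio.run(failing_up(None))
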
from sqlspec.migrations.loaders import PythonFileLoader + from sqlspec.migrations.loaders import PythonFileLoader + + context = MigrationContext(dialect="postgres") + loader = PythonFileLoader(migration_dir, tmp_path, context) - context = MigrationContext(dialect="postgres") - loader = PythonFileLoader(temp_migration_dir, temp_migration_dir.parent, context) + # Test that error is properly raised + async def test_error_handling() -> None: + with pytest.raises(Exception): # Should raise the ValueError from migration + await loader.get_up_sql(error_migration) - # Test that error is properly raised - async def test_error_handling() -> None: - with pytest.raises(Exception): # Should raise the ValueError from migration - await loader.get_up_sql(error_migration) + asyncio.run(test_error_handling()) - asyncio.run(test_error_handling()) - def test_config_resolver_with_list_configs(self) -> None: - """Test config resolver with list of configurations.""" - mock_config1 = Mock() - mock_config1.database_url = "sqlite:///test1.db" - mock_config1.bind_key = "test1" - mock_config1.migration_config = {} +def test_config_resolver_with_list_configs() -> None: + """Test config resolver with list of configurations.""" + mock_config1 = Mock() + mock_config1.database_url = "sqlite:///test1.db" + mock_config1.bind_key = "test1" + mock_config1.migration_config = {} - mock_config2 = Mock() - mock_config2.database_url = "sqlite:///test2.db" - mock_config2.bind_key = "test2" - mock_config2.migration_config = {} + mock_config2 = Mock() + mock_config2.database_url = "sqlite:///test2.db" + mock_config2.bind_key = "test2" + mock_config2.migration_config = {} - def get_configs() -> list[Mock]: - return [mock_config1, mock_config2] + def get_configs() -> list[Mock]: + return [mock_config1, mock_config2] - async def _test() -> None: - # Mock the import_string to return our function - import sqlspec.utils.config_resolver + async def _test() -> None: + # Mock the import_string to return our function + import sqlspec.utils.config_resolver - original_import = sqlspec.utils.config_resolver.import_string + original_import = sqlspec.utils.config_resolver.import_string - try: - sqlspec.utils.config_resolver.import_string = lambda path: get_configs - result = await resolve_config_async("test.config.get_database_configs") - assert isinstance(result, list) - assert len(result) == 2 - assert result[0] is mock_config1 - assert result[1] is mock_config2 - finally: - sqlspec.utils.config_resolver.import_string = original_import + try: + sqlspec.utils.config_resolver.import_string = lambda path: get_configs + result = await resolve_config_async("test.config.get_database_configs") + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] is mock_config1 + assert result[1] is mock_config2 + finally: + sqlspec.utils.config_resolver.import_string = original_import - run_(_test)() + run_(_test)() diff --git a/tests/integration/test_cli/test_sync_adapter_cli.py b/tests/integration/test_cli/test_sync_adapter_cli.py index 0d71765a..219a02e7 100644 --- a/tests/integration/test_cli/test_sync_adapter_cli.py +++ b/tests/integration/test_cli/test_sync_adapter_cli.py @@ -9,10 +9,7 @@ The tests use the full CLI workflow: init -> create-migration -> upgrade -> downgrade """ -import os -import shutil import sys -import tempfile import uuid from collections.abc import Generator from pathlib import Path @@ -26,16 +23,10 @@ @pytest.fixture -def temp_project_dir() -> Generator[Path, None, None]: +def temp_project_dir(tmp_path: Path, 
monkeypatch: pytest.MonkeyPatch) -> Generator[Path, None, None]: """Create a temporary project directory with cleanup.""" - temp_dir = Path(tempfile.mkdtemp()) - original_dir = os.getcwd() - os.chdir(temp_dir) - - yield temp_dir - - os.chdir(original_dir) - shutil.rmtree(temp_dir, ignore_errors=True) + monkeypatch.chdir(tmp_path) + yield tmp_path @pytest.fixture diff --git a/tests/integration/test_dishka/test_dishka_integration.py b/tests/integration/test_dishka/test_dishka_integration.py index 13061907..d4be70f0 100644 --- a/tests/integration/test_dishka/test_dishka_integration.py +++ b/tests/integration/test_dishka/test_dishka_integration.py @@ -1,9 +1,6 @@ """Integration tests for Dishka DI framework with SQLSpec CLI.""" -import os -import tempfile from pathlib import Path -from typing import Any import pytest from click.testing import CliRunner @@ -15,16 +12,13 @@ pytestmark = pytest.mark.xdist_group("dishka") -def test_simple_sync_dishka_provider(simple_sqlite_provider: Any) -> None: +def test_simple_sync_dishka_provider(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test CLI with a simple synchronous Dishka provider.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - # Create a module that uses Dishka container - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' from dishka import make_container from tests.integration.test_dishka.conftest import simple_sqlite_provider @@ -48,29 +42,24 @@ def get_database_config(self) -> SqliteConfig: with container() as request_container: return request_container.get(SqliteConfig) ''' - Path("dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), ["--config", "dishka_config.get_config_from_dishka", "show-config"] - ) + Path("dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "dishka_config.get_config_from_dishka", "show-config"] + ) - assert result.exit_code == 0 - assert "dishka_sqlite" in result.output - assert "Migration Enabled" in result.output or "migrations enabled" in result.output + assert result.exit_code == 0 + assert "dishka_sqlite" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output -def test_async_dishka_provider() -> None: +def test_async_dishka_provider(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test CLI with an asynchronous Dishka provider.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' import asyncio from dishka import make_async_container, Provider, provide, Scope from sqlspec.adapters.sqlite.config import SqliteConfig @@ -92,30 +81,24 @@ async def get_async_config_from_dishka(): async with container() as request_container: return await request_container.get(SqliteConfig) ''' - Path("async_dishka_config.py").write_text(config_module) + Path("async_dishka_config.py").write_text(config_module) - result = runner.invoke( - add_migration_commands(), - ["--config", "async_dishka_config.get_async_config_from_dishka", "show-config"], - ) - - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "async_dishka_config.get_async_config_from_dishka", "show-config"] + ) - assert result.exit_code == 0 - assert 
"async_dishka_sqlite" in result.output - assert "Migration Enabled" in result.output or "migrations enabled" in result.output + assert result.exit_code == 0 + assert "async_dishka_sqlite" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output -def test_multi_config_dishka_provider() -> None: +def test_multi_config_dishka_provider(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test CLI with Dishka provider returning multiple configs.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' from dishka import make_container, Provider, provide, Scope from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.adapters.duckdb import DuckDBConfig @@ -145,31 +128,25 @@ def get_multi_configs_from_dishka(): duckdb_config = request_container.get(DuckDBConfig) return [sqlite_config, duckdb_config] ''' - Path("multi_dishka_config.py").write_text(config_module) + Path("multi_dishka_config.py").write_text(config_module) - result = runner.invoke( - add_migration_commands(), - ["--config", "multi_dishka_config.get_multi_configs_from_dishka", "show-config"], - ) + result = runner.invoke( + add_migration_commands(), ["--config", "multi_dishka_config.get_multi_configs_from_dishka", "show-config"] + ) - finally: - os.chdir(original_dir) + assert result.exit_code == 0 + assert "dishka_multi_sqlite" in result.output + assert "dishka_multi_duckdb" in result.output + assert "2 configuration(s)" in result.output - assert result.exit_code == 0 - assert "dishka_multi_sqlite" in result.output - assert "dishka_multi_duckdb" in result.output - assert "2 configuration(s)" in result.output - -def test_async_multi_config_dishka_provider() -> None: +def test_async_multi_config_dishka_provider(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test CLI with async Dishka provider returning multiple configs.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' import asyncio from dishka import make_async_container, Provider, provide, Scope from sqlspec.adapters.sqlite.config import SqliteConfig @@ -213,32 +190,27 @@ async def get_async_multi_configs_from_dishka(): duckdb_config = await request_container.get(DuckDBConfig) return [sqlite_config, aiosqlite_config, duckdb_config] ''' - Path("async_multi_dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), - ["--config", "async_multi_dishka_config.get_async_multi_configs_from_dishka", "show-config"], - ) + Path("async_multi_dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), + ["--config", "async_multi_dishka_config.get_async_multi_configs_from_dishka", "show-config"], + ) - assert result.exit_code == 0 - assert "async_multi_sqlite" in result.output - assert "async_multi_aiosqlite" in result.output - assert "async_multi_duckdb" in result.output - assert "3 configuration(s)" in result.output + assert result.exit_code == 0 + assert "async_multi_sqlite" in result.output + assert "async_multi_aiosqlite" in result.output + assert "async_multi_duckdb" in result.output + assert "3 configuration(s)" in result.output -def test_dishka_provider_with_dependencies() -> None: +def 
test_dishka_provider_with_dependencies(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test Dishka provider that has complex dependencies.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' from dishka import make_container, Provider, provide, Scope from sqlspec.adapters.sqlite.config import SqliteConfig @@ -266,30 +238,24 @@ def get_complex_config_from_dishka(): with container() as request_container: return request_container.get(SqliteConfig) ''' - Path("complex_dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), - ["--config", "complex_dishka_config.get_complex_config_from_dishka", "show-config"], - ) + Path("complex_dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "complex_dishka_config.get_complex_config_from_dishka", "show-config"] + ) - assert result.exit_code == 0 - assert "complex_dishka" in result.output - assert "complex_migrations" in result.output + assert result.exit_code == 0 + assert "complex_dishka" in result.output + assert "complex_migrations" in result.output -def test_dishka_error_handling() -> None: +def test_dishka_error_handling(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test proper error handling when Dishka container fails.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' from dishka import make_container, Provider from sqlspec.adapters.sqlite.config import SqliteConfig @@ -303,29 +269,24 @@ def get_failing_dishka_config(): # This should raise an exception return request_container.get(SqliteConfig) ''' - Path("failing_dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), ["--config", "failing_dishka_config.get_failing_dishka_config", "show-config"] - ) + Path("failing_dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "failing_dishka_config.get_failing_dishka_config", "show-config"] + ) - assert result.exit_code == 1 - assert "Error loading config" in result.output - assert "Failed to execute callable config" in result.output + assert result.exit_code == 1 + assert "Error loading config" in result.output + assert "Failed to execute callable config" in result.output -def test_dishka_async_with_migration_commands() -> None: +def test_dishka_async_with_migration_commands(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test that migration commands work with async Dishka configs.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' import asyncio from dishka import make_async_container, Provider, provide, Scope from sqlspec.adapters.sqlite.config import SqliteConfig @@ -349,31 +310,25 @@ async def get_migration_config_from_dishka(): async with container() as request_container: return await request_container.get(SqliteConfig) ''' - Path("migration_dishka_config.py").write_text(config_module) + Path("migration_dishka_config.py").write_text(config_module) - # Test that the config 
loads properly for migration commands - result = runner.invoke( - add_migration_commands(), - ["--config", "migration_dishka_config.get_migration_config_from_dishka", "show-config"], - ) - - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), + ["--config", "migration_dishka_config.get_migration_config_from_dishka", "show-config"], + ) - assert result.exit_code == 0 - assert "migration_dishka" in result.output - assert "dishka_migrations" in result.output or "Migration Enabled" in result.output + assert result.exit_code == 0 + assert "migration_dishka" in result.output + assert "dishka_migrations" in result.output or "Migration Enabled" in result.output -def test_dishka_with_config_validation() -> None: +def test_dishka_with_config_validation(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test Dishka integration with config validation enabled.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' import asyncio from dishka import make_async_container, Provider, provide, Scope from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -394,40 +349,29 @@ async def get_validated_config_from_dishka(): async with container() as request_container: return await request_container.get(DuckDBConfig) ''' - Path("validated_dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), - [ - "--config", - "validated_dishka_config.get_validated_config_from_dishka", - "--validate-config", - "show-config", - ], - ) + Path("validated_dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), + ["--config", "validated_dishka_config.get_validated_config_from_dishka", "--validate-config", "show-config"], + ) - assert result.exit_code == 0 - assert "Successfully loaded 1 config(s)" in result.output - assert "validated_dishka" in result.output + assert result.exit_code == 0 + assert "Successfully loaded 1 config(s)" in result.output + assert "validated_dishka" in result.output -def test_real_world_dishka_scenario() -> None: +def test_real_world_dishka_scenario(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test a real-world scenario mimicking the user's issue.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - # Create a structure similar to user's litestar_dishka_modular.sqlspec_main.main - Path("litestar_dishka_modular").mkdir() - Path("litestar_dishka_modular/__init__.py").write_text("") - Path("litestar_dishka_modular/sqlspec_main.py").write_text("") + monkeypatch.chdir(tmp_path) + + Path("litestar_dishka_modular").mkdir() + Path("litestar_dishka_modular/__init__.py").write_text("") + Path("litestar_dishka_modular/sqlspec_main.py").write_text("") - config_module = ''' + config_module = ''' """Simulates the user's actual Dishka configuration.""" import asyncio from typing import List @@ -472,31 +416,25 @@ async def main() -> List: analytics_config = await request_container.get(DuckDBConfig) return [primary_config, analytics_config] ''' - Path("litestar_dishka_modular/sqlspec_main.py").write_text(config_module) - - # Test the exact command that was failing for the user - result = runner.invoke( - add_migration_commands(), ["--config", "litestar_dishka_modular.sqlspec_main.main", "show-config"] - ) + 
Path("litestar_dishka_modular/sqlspec_main.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "litestar_dishka_modular.sqlspec_main.main", "show-config"] + ) - assert result.exit_code == 0 - assert "primary_db" in result.output - assert "analytics_db" in result.output - assert "2 configuration(s)" in result.output + assert result.exit_code == 0 + assert "primary_db" in result.output + assert "analytics_db" in result.output + assert "2 configuration(s)" in result.output -def test_dishka_provider_cleanup() -> None: +def test_dishka_provider_cleanup(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test that Dishka providers are properly cleaned up.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = ''' + monkeypatch.chdir(tmp_path) + + config_module = ''' import asyncio from dishka import make_async_container, Provider, provide, Scope from sqlspec.adapters.sqlite.config import SqliteConfig @@ -524,15 +462,11 @@ async def get_cleanup_config(): config = await request_container.get(SqliteConfig) return config ''' - Path("cleanup_dishka_config.py").write_text(config_module) - - result = runner.invoke( - add_migration_commands(), ["--config", "cleanup_dishka_config.get_cleanup_config", "show-config"] - ) + Path("cleanup_dishka_config.py").write_text(config_module) - finally: - os.chdir(original_dir) + result = runner.invoke( + add_migration_commands(), ["--config", "cleanup_dishka_config.get_cleanup_config", "show-config"] + ) - assert result.exit_code == 0 - assert "cleanup_test" in result.output - # The container should have been cleaned up without errors + assert result.exit_code == 0 + assert "cleanup_test" in result.output diff --git a/tests/integration/test_loader/test_file_system_loading.py b/tests/integration/test_loader/test_file_system_loading.py index eab7fb78..cbcae806 100644 --- a/tests/integration/test_loader/test_file_system_loading.py +++ b/tests/integration/test_loader/test_file_system_loading.py @@ -9,9 +9,7 @@ """ import os -import tempfile import time -from collections.abc import Generator from pathlib import Path from unittest.mock import Mock, patch @@ -22,21 +20,13 @@ from sqlspec.loader import SQLFileLoader -@pytest.fixture -def temp_workspace() -> Generator[Path, None, None]: - """Create a temporary workspace for file system tests.""" - with tempfile.TemporaryDirectory() as temp_dir: - workspace = Path(temp_dir) - yield workspace - - -def test_load_single_file_from_filesystem(temp_workspace: Path) -> None: +def test_load_single_file_from_filesystem(tmp_path: Path) -> None: """Test loading a single SQL file from the file system. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - sql_file = temp_workspace / "test_queries.sql" + sql_file = tmp_path / "test_queries.sql" sql_file.write_text(""" -- name: get_user_count SELECT COUNT(*) as total_users FROM users; @@ -57,13 +47,13 @@ def test_load_single_file_from_filesystem(temp_workspace: Path) -> None: assert "COUNT(*)" in user_count_sql.sql -def test_load_multiple_files_from_filesystem(temp_workspace: Path) -> None: +def test_load_multiple_files_from_filesystem(tmp_path: Path) -> None: """Test loading multiple SQL files from the file system. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. 
""" - users_file = temp_workspace / "users.sql" + users_file = tmp_path / "users.sql" users_file.write_text(""" -- name: create_user INSERT INTO users (name, email) VALUES (:name, :email); @@ -72,7 +62,7 @@ def test_load_multiple_files_from_filesystem(temp_workspace: Path) -> None: UPDATE users SET email = :email WHERE id = :user_id; """) - products_file = temp_workspace / "products.sql" + products_file = tmp_path / "products.sql" products_file.write_text(""" -- name: list_products SELECT id, name, price FROM products ORDER BY name; @@ -95,13 +85,13 @@ def test_load_multiple_files_from_filesystem(temp_workspace: Path) -> None: assert str(products_file) in files -def test_load_directory_structure_from_filesystem(temp_workspace: Path) -> None: +def test_load_directory_structure_from_filesystem(tmp_path: Path) -> None: """Test loading entire directory structures from file system. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - queries_dir = temp_workspace / "queries" + queries_dir = tmp_path / "queries" queries_dir.mkdir() analytics_dir = queries_dir / "analytics" @@ -110,7 +100,7 @@ def test_load_directory_structure_from_filesystem(temp_workspace: Path) -> None: admin_dir = queries_dir / "admin" admin_dir.mkdir() - (temp_workspace / "root.sql").write_text(""" + (tmp_path / "root.sql").write_text(""" -- name: health_check SELECT 'OK' as status; """) @@ -134,7 +124,7 @@ def test_load_directory_structure_from_filesystem(temp_workspace: Path) -> None: """) loader = SQLFileLoader() - loader.load_sql(temp_workspace) + loader.load_sql(tmp_path) queries = loader.list_queries() @@ -146,13 +136,13 @@ def test_load_directory_structure_from_filesystem(temp_workspace: Path) -> None: assert "queries.admin.cleanup_old_logs" in queries -def test_file_content_encoding_handling(temp_workspace: Path) -> None: +def test_file_content_encoding_handling(tmp_path: Path) -> None: """Test handling of different file encodings. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - utf8_file = temp_workspace / "utf8_queries.sql" + utf8_file = tmp_path / "utf8_queries.sql" utf8_content = """ -- name: unicode_query -- Test with Unicode: 测试 файл עברית @@ -170,13 +160,13 @@ def test_file_content_encoding_handling(temp_workspace: Path) -> None: assert isinstance(sql, SQL) -def test_file_modification_detection(temp_workspace: Path) -> None: +def test_file_modification_detection(tmp_path: Path) -> None: """Test detection of file modifications. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - sql_file = temp_workspace / "modifiable.sql" + sql_file = tmp_path / "modifiable.sql" original_content = """ -- name: original_query SELECT 'original' as version; @@ -208,19 +198,19 @@ def test_file_modification_detection(temp_workspace: Path) -> None: assert "original_query" not in queries -def test_symlink_resolution(temp_workspace: Path) -> None: +def test_symlink_resolution(tmp_path: Path) -> None: """Test resolution of symbolic links. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. 
""" - original_file = temp_workspace / "original.sql" + original_file = tmp_path / "original.sql" original_file.write_text(""" -- name: symlinked_query SELECT 'from symlink' as source; """) - symlink_file = temp_workspace / "linked.sql" + symlink_file = tmp_path / "linked.sql" try: symlink_file.symlink_to(original_file) except OSError: @@ -233,30 +223,30 @@ def test_symlink_resolution(temp_workspace: Path) -> None: assert "symlinked_query" in queries -def test_nonexistent_file_error(temp_workspace: Path) -> None: +def test_nonexistent_file_error(tmp_path: Path) -> None: """Test error handling for nonexistent files. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. Raises: SQLFileNotFoundError: When attempting to load nonexistent file. """ loader = SQLFileLoader() - nonexistent_file = temp_workspace / "does_not_exist.sql" + nonexistent_file = tmp_path / "does_not_exist.sql" with pytest.raises(SQLFileNotFoundError): loader.load_sql(nonexistent_file) -def test_nonexistent_directory_handling(temp_workspace: Path) -> None: +def test_nonexistent_directory_handling(tmp_path: Path) -> None: """Test handling of nonexistent directories. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ loader = SQLFileLoader() - nonexistent_dir = temp_workspace / "does_not_exist" + nonexistent_dir = tmp_path / "does_not_exist" loader.load_sql(nonexistent_dir) @@ -264,11 +254,11 @@ def test_nonexistent_directory_handling(temp_workspace: Path) -> None: assert loader.list_files() == [] -def test_permission_denied_error(temp_workspace: Path) -> None: +def test_permission_denied_error(tmp_path: Path) -> None: """Test handling of permission denied errors. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. Raises: SQLFileParseError: When file permissions prevent reading. @@ -276,7 +266,7 @@ def test_permission_denied_error(temp_workspace: Path) -> None: if os.name == "nt": pytest.skip("Permission testing not reliable on Windows") - restricted_file = temp_workspace / "restricted.sql" + restricted_file = tmp_path / "restricted.sql" restricted_file.write_text(""" -- name: restricted_query SELECT 'restricted' as access; @@ -293,13 +283,13 @@ def test_permission_denied_error(temp_workspace: Path) -> None: restricted_file.chmod(0o644) -def test_corrupted_file_handling(temp_workspace: Path) -> None: +def test_corrupted_file_handling(tmp_path: Path) -> None: """Test handling of SQL files without named statements. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - corrupted_file = temp_workspace / "corrupted.sql" + corrupted_file = tmp_path / "corrupted.sql" corrupted_file.write_text(""" This is not a valid SQL file with named queries. @@ -315,13 +305,13 @@ def test_corrupted_file_handling(temp_workspace: Path) -> None: assert str(corrupted_file) not in loader._files # pyright: ignore -def test_empty_file_handling(temp_workspace: Path) -> None: +def test_empty_file_handling(tmp_path: Path) -> None: """Test handling of empty files. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. 
""" - empty_file = temp_workspace / "empty.sql" + empty_file = tmp_path / "empty.sql" empty_file.write_text("") loader = SQLFileLoader() @@ -332,16 +322,16 @@ def test_empty_file_handling(temp_workspace: Path) -> None: assert str(empty_file) not in loader._files # pyright: ignore -def test_binary_file_handling(temp_workspace: Path) -> None: +def test_binary_file_handling(tmp_path: Path) -> None: """Test handling of binary files with .sql extension. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. Raises: SQLFileParseError: When file contains binary data that can't be decoded. """ - binary_file = temp_workspace / "binary.sql" + binary_file = tmp_path / "binary.sql" Path(binary_file).write_bytes(b"\xff\xfe\xfd\xfc") @@ -351,13 +341,13 @@ def test_binary_file_handling(temp_workspace: Path) -> None: loader.load_sql(binary_file) -def test_large_file_loading_performance(temp_workspace: Path) -> None: +def test_large_file_loading_performance(tmp_path: Path) -> None: """Test performance with large SQL files. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - large_file = temp_workspace / "large_queries.sql" + large_file = tmp_path / "large_queries.sql" large_content = "\n".join( f""" @@ -390,13 +380,13 @@ def test_large_file_loading_performance(temp_workspace: Path) -> None: assert load_time < 5.0, f"Loading took too long: {load_time:.2f}s" -def test_many_small_files_performance(temp_workspace: Path) -> None: +def test_many_small_files_performance(tmp_path: Path) -> None: """Test performance with many small SQL files. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - files_dir = temp_workspace / "many_files" + files_dir = tmp_path / "many_files" files_dir.mkdir() for i in range(100): @@ -420,13 +410,13 @@ def test_many_small_files_performance(temp_workspace: Path) -> None: assert load_time < 10.0, f"Loading took too long: {load_time:.2f}s" -def test_deep_directory_structure_performance(temp_workspace: Path) -> None: +def test_deep_directory_structure_performance(tmp_path: Path) -> None: """Test performance with deep directory structures. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - current_path = temp_workspace + current_path = tmp_path for level in range(10): current_path = current_path / f"level_{level}" current_path.mkdir() @@ -440,7 +430,7 @@ def test_deep_directory_structure_performance(temp_workspace: Path) -> None: loader = SQLFileLoader() start_time = time.time() - loader.load_sql(temp_workspace) + loader.load_sql(tmp_path) end_time = time.time() load_time = end_time - start_time @@ -455,13 +445,13 @@ def test_deep_directory_structure_performance(temp_workspace: Path) -> None: assert load_time < 5.0, f"Loading took too long: {load_time:.2f}s" -def test_concurrent_file_modification(temp_workspace: Path) -> None: +def test_concurrent_file_modification(tmp_path: Path) -> None: """Test handling of concurrent file modifications. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. 
""" - shared_file = temp_workspace / "shared.sql" + shared_file = tmp_path / "shared.sql" shared_file.write_text(""" -- name: shared_query_v1 @@ -492,13 +482,13 @@ def test_concurrent_file_modification(temp_workspace: Path) -> None: assert "shared_query_v2" not in loader2.list_queries() -def test_multiple_loaders_same_file(temp_workspace: Path) -> None: +def test_multiple_loaders_same_file(tmp_path: Path) -> None: """Test multiple loaders accessing the same file. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - sql_file = temp_workspace / "multi_access.sql" + sql_file = tmp_path / "multi_access.sql" sql_file.write_text(""" -- name: multi_access_query SELECT 'accessed by multiple loaders' as message; @@ -517,14 +507,14 @@ def test_multiple_loaders_same_file(temp_workspace: Path) -> None: assert isinstance(sql, SQL) -def test_loader_isolation(temp_workspace: Path) -> None: +def test_loader_isolation(tmp_path: Path) -> None: """Test that loaders are properly isolated from each other. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - file1 = temp_workspace / "loader1.sql" - file2 = temp_workspace / "loader2.sql" + file1 = tmp_path / "loader1.sql" + file2 = tmp_path / "loader2.sql" file1.write_text(""" -- name: loader1_query @@ -552,14 +542,14 @@ def test_loader_isolation(temp_workspace: Path) -> None: assert "loader2_query" not in queries1 -def test_file_cache_persistence_across_loaders(temp_workspace: Path) -> None: +def test_file_cache_persistence_across_loaders(tmp_path: Path) -> None: """Test that file cache persists across different loader instances. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - sql_file = temp_workspace / "cached.sql" + sql_file = tmp_path / "cached.sql" sql_file.write_text(""" -- name: cached_query SELECT 'cached content' as status; @@ -586,14 +576,14 @@ def test_file_cache_persistence_across_loaders(temp_workspace: Path) -> None: assert cache_load_time < 1.0 -def test_cache_invalidation_on_file_change(temp_workspace: Path) -> None: +def test_cache_invalidation_on_file_change(tmp_path: Path) -> None: """Test cache invalidation when files change. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - sql_file = temp_workspace / "changing.sql" + sql_file = tmp_path / "changing.sql" original_content = """ -- name: changing_query_v1 @@ -626,16 +616,16 @@ def test_cache_invalidation_on_file_change(temp_workspace: Path) -> None: assert "changing_query_v1" not in queries -def test_cache_behavior_with_file_deletion(temp_workspace: Path) -> None: +def test_cache_behavior_with_file_deletion(tmp_path: Path) -> None: """Test cache behavior when cached files are deleted. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. Raises: SQLFileNotFoundError: When attempting to load deleted file. """ - sql_file = temp_workspace / "deletable.sql" + sql_file = tmp_path / "deletable.sql" sql_file.write_text(""" -- name: deletable_query SELECT 'will be deleted' as status; @@ -656,14 +646,14 @@ def test_cache_behavior_with_file_deletion(temp_workspace: Path) -> None: assert "deletable_query" in loader.list_queries() -def test_unicode_file_names(temp_workspace: Path) -> None: +def test_unicode_file_names(tmp_path: Path) -> None: """Test handling of Unicode file names. 
Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ try: - unicode_file = temp_workspace / "测试_файл_test.sql" + unicode_file = tmp_path / "测试_файл_test.sql" unicode_file.write_text( """ -- name: unicode_filename_query @@ -681,13 +671,13 @@ def test_unicode_file_names(temp_workspace: Path) -> None: assert "unicode_filename_query" in queries -def test_unicode_file_content(temp_workspace: Path) -> None: +def test_unicode_file_content(tmp_path: Path) -> None: """Test handling of Unicode content in files. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - unicode_file = temp_workspace / "unicode_content.sql" + unicode_file = tmp_path / "unicode_content.sql" unicode_content = """ -- name: unicode_content_query @@ -708,13 +698,13 @@ def test_unicode_file_content(temp_workspace: Path) -> None: assert "Unicode: 测试 тест עברית" in sql.sql -def test_mixed_encoding_handling(temp_workspace: Path) -> None: +def test_mixed_encoding_handling(tmp_path: Path) -> None: """Test handling of different encodings. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. """ - utf8_file = temp_workspace / "utf8.sql" + utf8_file = tmp_path / "utf8.sql" utf8_file.write_text( """ -- name: utf8_query @@ -723,7 +713,7 @@ def test_mixed_encoding_handling(temp_workspace: Path) -> None: encoding="utf-8", ) - latin1_file = temp_workspace / "latin1.sql" + latin1_file = tmp_path / "latin1.sql" latin1_content = """ -- name: latin1_query SELECT 'Latin-1: café' as message; @@ -741,14 +731,14 @@ def test_mixed_encoding_handling(temp_workspace: Path) -> None: assert "latin1_query" in latin1_loader.list_queries() -def test_special_characters_in_paths(temp_workspace: Path) -> None: +def test_special_characters_in_paths(tmp_path: Path) -> None: """Test handling of special characters in file paths. Args: - temp_workspace: Temporary directory for test files. + tmp_path: Temporary directory for test files. 
""" try: - special_dir = temp_workspace / "special-chars_&_symbols!@#$" + special_dir = tmp_path / "special-chars_&_symbols!@#$" special_dir.mkdir() special_file = special_dir / "query-file_with&symbols.sql" diff --git a/tests/unit/test_cli/test_config_loading.py b/tests/unit/test_cli/test_config_loading.py index d440a304..d03e341e 100644 --- a/tests/unit/test_cli/test_config_loading.py +++ b/tests/unit/test_cli/test_config_loading.py @@ -1,8 +1,6 @@ """Tests for CLI configuration loading functionality.""" -import os import sys -import tempfile import uuid from collections.abc import Iterator from pathlib import Path @@ -34,13 +32,14 @@ def _create_module(path: "Path", content: str) -> str: return module_name -def test_direct_config_instance_loading(cleanup_test_modules: None) -> None: +def test_direct_config_instance_loading( + tmp_path: Path, cleanup_test_modules: None, monkeypatch: pytest.MonkeyPatch +) -> None: """Test loading a direct config instance through CLI.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - # Create a test module with a direct config instance - config_module = """ + # Create a test module with a direct config instance + config_module = """ from sqlspec.adapters.sqlite.config import SqliteConfig config = SqliteConfig( @@ -50,30 +49,25 @@ def test_direct_config_instance_loading(cleanup_test_modules: None) -> None: ) database_config = config """ - module_name = _create_module(Path(temp_dir), config_module) + module_name = _create_module(tmp_path, config_module) - # Change to the temp directory - original_cwd = os.getcwd() - try: - os.chdir(temp_dir) - result = runner.invoke( - add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"] - ) - finally: - os.chdir(original_cwd) + # Change to the temp directory + monkeypatch.chdir(tmp_path) + result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"]) - assert result.exit_code == 0 - assert "test" in result.output - assert "Migration Enabled" in result.output or "migrations enabled" in result.output + assert result.exit_code == 0 + assert "test" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output -def test_sync_callable_config_loading(cleanup_test_modules: None) -> None: +def test_sync_callable_config_loading( + tmp_path: Path, cleanup_test_modules: None, monkeypatch: pytest.MonkeyPatch +) -> None: """Test loading config from synchronous callable through CLI.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - # Create a test module with sync callable - config_module = """ + # Create a test module with sync callable + config_module = """ from sqlspec.adapters.sqlite.config import SqliteConfig def get_database_config(): @@ -84,30 +78,25 @@ def get_database_config(): ) return config """ - module_name = _create_module(Path(temp_dir), config_module) + module_name = _create_module(tmp_path, config_module) - # Change to the temp directory - original_cwd = os.getcwd() - try: - os.chdir(temp_dir) - result = runner.invoke( - add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"] - ) - finally: - os.chdir(original_cwd) + # Change to the temp directory + monkeypatch.chdir(tmp_path) + result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"]) - assert result.exit_code == 0 - assert "sync_test" in result.output - assert "Migration Enabled" in result.output or 
"migrations enabled" in result.output + assert result.exit_code == 0 + assert "sync_test" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output -def test_async_callable_config_loading(cleanup_test_modules: None) -> None: +def test_async_callable_config_loading( + tmp_path: Path, cleanup_test_modules: None, monkeypatch: pytest.MonkeyPatch +) -> None: """Test loading config from asynchronous callable through CLI.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - # Create a test module with async callable - config_module = """ + # Create a test module with async callable + config_module = """ import asyncio from sqlspec.adapters.sqlite.config import SqliteConfig @@ -121,32 +110,27 @@ async def get_database_config(): ) return config """ - module_name = _create_module(Path(temp_dir), config_module) - - # Change to the temp directory - original_cwd = os.getcwd() - try: - os.chdir(temp_dir) - result = runner.invoke( - add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"] - ) - finally: - os.chdir(original_cwd) - - if result.exception: - pass - assert result.exit_code == 0 - assert "async_test" in result.output - assert "Migration Enabled" in result.output or "migrations enabled" in result.output - - -def test_show_config_with_path_object(cleanup_test_modules: None) -> None: + module_name = _create_module(tmp_path, config_module) + + # Change to the temp directory + monkeypatch.chdir(tmp_path) + result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_database_config", "show-config"]) + + if result.exception: + pass + assert result.exit_code == 0 + assert "async_test" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output + + +def test_show_config_with_path_object( + tmp_path: Path, cleanup_test_modules: None, monkeypatch: pytest.MonkeyPatch +) -> None: """Test show-config handles Path objects in script_location without crashing.""" runner = CliRunner() - with tempfile.TemporaryDirectory() as temp_dir: - # Create a test module with Path object in script_location - config_module = """ + # Create a test module with Path object in script_location + config_module = """ from pathlib import Path from sqlspec.adapters.sqlite.config import SqliteConfig @@ -157,19 +141,13 @@ def test_show_config_with_path_object(cleanup_test_modules: None) -> None: ) database_config = config """ - module_name = _create_module(Path(temp_dir), config_module) - - # Change to the temp directory - original_cwd = os.getcwd() - try: - os.chdir(temp_dir) - result = runner.invoke( - add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"] - ) - finally: - os.chdir(original_cwd) - - assert result.exit_code == 0 - assert "path_test" in result.output - assert "custom_migrations" in result.output - assert "Migration Enabled" in result.output or "migrations enabled" in result.output + module_name = _create_module(tmp_path, config_module) + + # Change to the temp directory + monkeypatch.chdir(tmp_path) + result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.database_config", "show-config"]) + + assert result.exit_code == 0 + assert "path_test" in result.output + assert "custom_migrations" in result.output + assert "Migration Enabled" in result.output or "migrations enabled" in result.output diff --git a/tests/unit/test_cli/test_migration_commands.py b/tests/unit/test_cli/test_migration_commands.py 
diff --git a/tests/unit/test_cli/test_migration_commands.py b/tests/unit/test_cli/test_migration_commands.py
index 7c89f922..c7987bfa 100644
--- a/tests/unit/test_cli/test_migration_commands.py
+++ b/tests/unit/test_cli/test_migration_commands.py
@@ -1,8 +1,6 @@
 """Tests for CLI migration commands functionality."""

-import os
 import sys
-import tempfile
 import uuid
 from collections.abc import Iterator
 from pathlib import Path
@@ -39,15 +37,12 @@ def _create_module(content: str, directory: "Path") -> str:
     return module_name

-def test_show_config_command(cleanup_test_modules: None) -> None:
+def test_show_config_command(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None) -> None:
     """Test show-config command displays migration configurations."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -61,27 +56,23 @@ def get_config():
     )
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])

-        finally:
-            os.chdir(original_dir)
+    assert result.exit_code == 0
+    assert "migration_test" in result.output
+    assert "Migration Enabled" in result.output or "SqliteConfig" in result.output

-        assert result.exit_code == 0
-        assert "migration_test" in result.output
-        assert "Migration Enabled" in result.output or "SqliteConfig" in result.output
-
-def test_show_config_with_multiple_configs(cleanup_test_modules: None) -> None:
+def test_show_config_with_multiple_configs(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test show-config with multiple migration configurations."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig
 from sqlspec.adapters.duckdb.config import DuckDBConfig

@@ -100,28 +91,22 @@ def get_configs():
     return [sqlite_config, duckdb_config]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_configs", "show-config"])
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_configs", "show-config"])

-        finally:
-            os.chdir(original_dir)
+    assert result.exit_code == 0
+    assert "sqlite_migrations" in result.output
+    assert "duckdb_migrations" in result.output
+    assert "2 configuration(s)" in result.output

-        assert result.exit_code == 0
-        assert "sqlite_migrations" in result.output
-        assert "duckdb_migrations" in result.output
-        assert "2 configuration(s)" in result.output
-
-def test_show_config_no_migrations(cleanup_test_modules: None) -> None:
+def test_show_config_no_migrations(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None) -> None:
     """Test show-config when no migrations are configured."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -132,34 +117,30 @@ def get_config():
     )
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-config"])

-        assert result.exit_code == 0
-        assert (
-            "No configurations with migrations detected" in result.output or "no_migrations" in result.output
-        )  # Depends on validation logic
+    assert result.exit_code == 0
+    assert (
+        "No configurations with migrations detected" in result.output or "no_migrations" in result.output
+    )  # Depends on validation logic

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_show_current_revision_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_show_current_revision_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test show-current-revision command."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     # Mock the migration commands
     mock_commands = Mock()
     mock_commands.current = Mock(return_value=None)  # Sync function
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -170,33 +151,27 @@ def get_config():
     )
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "show-current-revision"]
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "show-current-revision"])

-        assert result.exit_code == 0
-        mock_commands.current.assert_called_once_with(verbose=False)
+    assert result.exit_code == 0
+    mock_commands.current.assert_called_once_with(verbose=False)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_show_current_revision_verbose(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_show_current_revision_verbose(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test show-current-revision command with verbose flag."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.current = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -207,34 +182,29 @@ def get_config():
     )
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_config", "show-current-revision", "--verbose"],
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_config", "show-current-revision", "--verbose"]
+    )

-        assert result.exit_code == 0
-        mock_commands.current.assert_called_once_with(verbose=True)
+    assert result.exit_code == 0
+    mock_commands.current.assert_called_once_with(verbose=True)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_init_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_init_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test init command for initializing migrations."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.init = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -243,33 +213,27 @@ def get_config():
     config.migration_config = {"script_location": "test_migrations"}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "init", "--no-prompt"]
-            )
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "init", "--no-prompt"])

-        finally:
-            os.chdir(original_dir)
-
-        assert result.exit_code == 0
-        mock_commands.init.assert_called_once_with(directory="test_migrations", package=True)
+    assert result.exit_code == 0
+    mock_commands.init.assert_called_once_with(directory="test_migrations", package=True)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_init_command_custom_directory(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_init_command_custom_directory(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test init command with custom directory."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.init = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -278,34 +242,29 @@ def get_config():
     config.migration_config = {"script_location": "migrations"}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_config", "init", "custom_migrations", "--no-prompt"],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_config", "init", "custom_migrations", "--no-prompt"]
+    )

-        assert result.exit_code == 0
-        mock_commands.init.assert_called_once_with(directory="custom_migrations", package=True)
+    assert result.exit_code == 0
+    mock_commands.init.assert_called_once_with(directory="custom_migrations", package=True)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_create_migration_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_create_migration_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test create-migration command."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.revision = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -314,34 +273,30 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_config", "create-migration", "-m", "test migration", "--no-prompt"],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(),
+        ["--config", f"{module_name}.get_config", "create-migration", "-m", "test migration", "--no-prompt"],
+    )

-        assert result.exit_code == 0
-        mock_commands.revision.assert_called_once_with(message="test migration", file_type=None)
+    assert result.exit_code == 0
+    mock_commands.revision.assert_called_once_with(message="test migration", file_type=None)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_make_migration_alias(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_make_migration_alias(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test make-migration alias for backward compatibility."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.revision = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -350,33 +305,29 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_config", "make-migration", "-m", "test migration", "--no-prompt"],
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(),
+        ["--config", f"{module_name}.get_config", "make-migration", "-m", "test migration", "--no-prompt"],
+    )

-        assert result.exit_code == 0
-        mock_commands.revision.assert_called_once_with(message="test migration", file_type=None)
+    assert result.exit_code == 0
+    mock_commands.revision.assert_called_once_with(message="test migration", file_type=None)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_create_migration_command_with_format(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_create_migration_command_with_format(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.revision = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -385,44 +336,38 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                [
-                    "--config",
-                    f"{module_name}.get_config",
-                    "create-migration",
-                    "-m",
-                    "test migration",
-                    "--format",
-                    "py",
-                    "--no-prompt",
-                ],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    module_name = _create_module(config_module, tmp_path)
+
+    result = runner.invoke(
+        add_migration_commands(),
+        [
+            "--config",
+            f"{module_name}.get_config",
+            "create-migration",
+            "-m",
+            "test migration",
+            "--format",
+            "py",
+            "--no-prompt",
+        ],
+    )

-        assert result.exit_code == 0
-        mock_commands.revision.assert_called_once_with(message="test migration", file_type="py")
+    assert result.exit_code == 0
+    mock_commands.revision.assert_called_once_with(message="test migration", file_type="py")

 @patch("sqlspec.migrations.commands.create_migration_commands")
 def test_create_migration_command_with_file_type_alias(
-    mock_create_commands: "Mock", cleanup_test_modules: None
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
 ) -> None:
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.revision = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -431,43 +376,39 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                [
-                    "--config",
-                    f"{module_name}.get_config",
-                    "create-migration",
-                    "-m",
-                    "test migration",
-                    "--file-type",
-                    "sql",
-                    "--no-prompt",
-                ],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    module_name = _create_module(config_module, tmp_path)
+
+    result = runner.invoke(
+        add_migration_commands(),
+        [
+            "--config",
+            f"{module_name}.get_config",
+            "create-migration",
+            "-m",
+            "test migration",
+            "--file-type",
+            "sql",
+            "--no-prompt",
+        ],
+    )

-        assert result.exit_code == 0
-        mock_commands.revision.assert_called_once_with(message="test migration", file_type="sql")
+    assert result.exit_code == 0
+    mock_commands.revision.assert_called_once_with(message="test migration", file_type="sql")

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_upgrade_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_upgrade_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test upgrade command."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.upgrade = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -476,33 +417,29 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "--no-prompt"]
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "--no-prompt"]
+    )

-        assert result.exit_code == 0
-        mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False)
+    assert result.exit_code == 0
+    mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_upgrade_command_specific_revision(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_upgrade_command_specific_revision(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test upgrade command with specific revision."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.upgrade = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -511,33 +448,29 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "abc123", "--no-prompt"]
-            )
-
-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_config", "upgrade", "abc123", "--no-prompt"]
+    )

-        assert result.exit_code == 0
-        mock_commands.upgrade.assert_called_once_with(revision="abc123", auto_sync=True, dry_run=False)
+    assert result.exit_code == 0
+    mock_commands.upgrade.assert_called_once_with(revision="abc123", auto_sync=True, dry_run=False)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_downgrade_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_downgrade_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test downgrade command."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.downgrade = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -546,33 +479,29 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "downgrade", "--no-prompt"]
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_config", "downgrade", "--no-prompt"]
+    )

-        assert result.exit_code == 0
-        mock_commands.downgrade.assert_called_once_with(revision="-1", dry_run=False)
+    assert result.exit_code == 0
+    mock_commands.downgrade.assert_called_once_with(revision="-1", dry_run=False)

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_stamp_command(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_stamp_command(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test stamp command."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.stamp = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -581,33 +510,27 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_config", "stamp", "abc123"]
-            )
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_config", "stamp", "abc123"])

-        finally:
-            os.chdir(original_dir)
-
-        assert result.exit_code == 0
-        mock_commands.stamp.assert_called_once_with(revision="abc123")
+    assert result.exit_code == 0
+    mock_commands.stamp.assert_called_once_with(revision="abc123")

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_multi_config_operations(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_multi_config_operations(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test multi-configuration operations with include/exclude filters."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.current = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig
 from sqlspec.adapters.duckdb.config import DuckDBConfig

@@ -622,35 +545,31 @@ def get_configs():

     return [sqlite_config, duckdb_config]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_configs", "show-current-revision", "--include", "sqlite_multi"],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(),
+        ["--config", f"{module_name}.get_configs", "show-current-revision", "--include", "sqlite_multi"],
+    )

-        assert result.exit_code == 0
-        # Should process only the included configuration
-        assert "sqlite_multi" in result.output
+    assert result.exit_code == 0
+    # Should process only the included configuration
+    assert "sqlite_multi" in result.output

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_dry_run_operations(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_dry_run_operations(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test dry-run operations show what would be executed."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.upgrade = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_configs():
@@ -664,31 +583,23 @@ def get_configs():

     return [config1, config2]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(), ["--config", f"{module_name}.get_configs", "upgrade", "--dry-run"]
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(add_migration_commands(), ["--config", f"{module_name}.get_configs", "upgrade", "--dry-run"])

-        assert result.exit_code == 0
-        assert "Dry run" in result.output
-        assert "Would upgrade" in result.output
-        # Should not actually call the upgrade method with dry-run
-        mock_commands.upgrade.assert_not_called()
+    assert result.exit_code == 0
+    assert "Dry run" in result.output
+    assert "Would upgrade" in result.output
+    # Should not actually call the upgrade method with dry-run
+    mock_commands.upgrade.assert_not_called()

-def test_execution_mode_reporting(cleanup_test_modules: None) -> None:
+def test_execution_mode_reporting(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None) -> None:
     """Test that execution mode is reported when specified."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -697,34 +608,30 @@ def get_config():
     config.migration_config = {"enabled": True}
     return config
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            with patch("sqlspec.migrations.commands.create_migration_commands") as mock_create:
-                mock_commands = Mock()
-                mock_commands.upgrade = Mock(return_value=None)
-                mock_create.return_value = mock_commands
+    module_name = _create_module(config_module, tmp_path)

-                result = runner.invoke(
-                    add_migration_commands(),
-                    ["--config", f"{module_name}.get_config", "upgrade", "--execution-mode", "sync", "--no-prompt"],
-                )
+    with patch("sqlspec.migrations.commands.create_migration_commands") as mock_create:
+        mock_commands = Mock()
+        mock_commands.upgrade = Mock(return_value=None)
+        mock_create.return_value = mock_commands

-        finally:
-            os.chdir(original_dir)
+        result = runner.invoke(
+            add_migration_commands(),
+            ["--config", f"{module_name}.get_config", "upgrade", "--execution-mode", "sync", "--no-prompt"],
+        )

-        assert result.exit_code == 0
-        assert "Execution mode: sync" in result.output
+    assert result.exit_code == 0
+    assert "Execution mode: sync" in result.output

-def test_bind_key_filtering_single_config(cleanup_test_modules: None) -> None:
+def test_bind_key_filtering_single_config(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test --bind-key filtering with single config."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_config():
@@ -734,29 +641,25 @@ def get_config():
         migration_config={"enabled": True, "script_location": "migrations"}
     )
 """
-            module_name = _create_module(config_module, Path(temp_dir))
+    module_name = _create_module(config_module, tmp_path)

-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_config", "show-config", "--bind-key", "target_config"],
-            )
-
-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(),
+        ["--config", f"{module_name}.get_config", "show-config", "--bind-key", "target_config"],
+    )

-        assert result.exit_code == 0
-        assert "target_config" in result.output
+    assert result.exit_code == 0
+    assert "target_config" in result.output

-def test_bind_key_filtering_multiple_configs(cleanup_test_modules: None) -> None:
+def test_bind_key_filtering_multiple_configs(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test --bind-key filtering with multiple configs."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig
 from sqlspec.adapters.duckdb.config import DuckDBConfig

@@ -781,34 +684,29 @@ def get_configs():

     return [sqlite_config, duckdb_config, postgres_config]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            # Test filtering for sqlite_db only
-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "sqlite_db"],
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    # Test filtering for sqlite_db only
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "sqlite_db"]
+    )

-        assert result.exit_code == 0
-        assert "sqlite_db" in result.output
-        # Should only show one config, not all three
-        assert "Found 1 configuration(s)" in result.output or "sqlite_migrations" in result.output
-        assert "duckdb_db" not in result.output
-        assert "postgres_db" not in result.output
+    assert result.exit_code == 0
+    assert "sqlite_db" in result.output
+    # Should only show one config, not all three
+    assert "Found 1 configuration(s)" in result.output or "sqlite_migrations" in result.output
+    assert "duckdb_db" not in result.output
+    assert "postgres_db" not in result.output

-def test_bind_key_filtering_nonexistent_key(cleanup_test_modules: None) -> None:
+def test_bind_key_filtering_nonexistent_key(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test --bind-key filtering with nonexistent bind key."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig

 def get_configs():
@@ -820,34 +718,29 @@ def get_configs():
     )
     ]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "nonexistent"],
-            )
+    module_name = _create_module(config_module, tmp_path)

-        finally:
-            os.chdir(original_dir)
+    result = runner.invoke(
+        add_migration_commands(), ["--config", f"{module_name}.get_configs", "show-config", "--bind-key", "nonexistent"]
+    )

-        assert result.exit_code == 1
-        assert "No config found for bind key: nonexistent" in result.output
+    assert result.exit_code == 1
+    assert "No config found for bind key: nonexistent" in result.output

 @patch("sqlspec.migrations.commands.create_migration_commands")
-def test_bind_key_filtering_with_migration_commands(mock_create_commands: "Mock", cleanup_test_modules: None) -> None:
+def test_bind_key_filtering_with_migration_commands(
+    mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None
+) -> None:
     """Test --bind-key filtering works with actual migration commands."""
     runner = CliRunner()
+    monkeypatch.chdir(tmp_path)

     mock_commands = Mock()
     mock_commands.upgrade = Mock(return_value=None)
     mock_create_commands.return_value = mock_commands

-    with tempfile.TemporaryDirectory() as temp_dir:
-        original_dir = os.getcwd()
-        os.chdir(temp_dir)
-        try:
-            config_module = """
+    config_module = """
 from sqlspec.adapters.sqlite.config import SqliteConfig
 from sqlspec.adapters.duckdb.config import DuckDBConfig

@@ -865,23 +758,13 @@ def get_multi_configs():
     )
     ]
 """
-            module_name = _create_module(config_module, Path(temp_dir))
-
-            result = runner.invoke(
-                add_migration_commands(),
-                [
-                    "--config",
-                    f"{module_name}.get_multi_configs",
-                    "upgrade",
-                    "--bind-key",
-                    "analytics_db",
-                    "--no-prompt",
-                ],
-            )
-
-        finally:
-            os.chdir(original_dir)
-
-        assert result.exit_code == 0
-        # Should only process the analytics_db config
-        mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False)
+    module_name = _create_module(config_module, tmp_path)
+
+    result = runner.invoke(
+        add_migration_commands(),
+        ["--config", f"{module_name}.get_multi_configs", "upgrade", "--bind-key", "analytics_db", "--no-prompt"],
+    )
+
+    assert result.exit_code == 0
+    # Should only process the analytics_db config
+    mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False)
"nonexistent"] + ) - assert result.exit_code == 1 - assert "No config found for bind key: nonexistent" in result.output + assert result.exit_code == 1 + assert "No config found for bind key: nonexistent" in result.output @patch("sqlspec.migrations.commands.create_migration_commands") -def test_bind_key_filtering_with_migration_commands(mock_create_commands: "Mock", cleanup_test_modules: None) -> None: +def test_bind_key_filtering_with_migration_commands( + mock_create_commands: "Mock", tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cleanup_test_modules: None +) -> None: """Test --bind-key filtering works with actual migration commands.""" runner = CliRunner() + monkeypatch.chdir(tmp_path) mock_commands = Mock() mock_commands.upgrade = Mock(return_value=None) mock_create_commands.return_value = mock_commands - with tempfile.TemporaryDirectory() as temp_dir: - original_dir = os.getcwd() - os.chdir(temp_dir) - try: - config_module = """ + config_module = """ from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -865,23 +758,13 @@ def get_multi_configs(): ) ] """ - module_name = _create_module(config_module, Path(temp_dir)) - - result = runner.invoke( - add_migration_commands(), - [ - "--config", - f"{module_name}.get_multi_configs", - "upgrade", - "--bind-key", - "analytics_db", - "--no-prompt", - ], - ) - - finally: - os.chdir(original_dir) - - assert result.exit_code == 0 - # Should only process the analytics_db config - mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False) + module_name = _create_module(config_module, tmp_path) + + result = runner.invoke( + add_migration_commands(), + ["--config", f"{module_name}.get_multi_configs", "upgrade", "--bind-key", "analytics_db", "--no-prompt"], + ) + + assert result.exit_code == 0 + # Should only process the analytics_db config + mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False) diff --git a/tests/unit/test_config/test_migration_methods.py b/tests/unit/test_config/test_migration_methods.py index a29f7d19..773cf5c4 100644 --- a/tests/unit/test_config/test_migration_methods.py +++ b/tests/unit/test_config/test_migration_methods.py @@ -16,7 +16,6 @@ - AsyncDatabaseConfig (async, pooled) """ -import tempfile from pathlib import Path from unittest.mock import patch @@ -74,451 +73,397 @@ def test_no_pool_async_config_has_migration_methods() -> None: assert hasattr(NoPoolAsyncConfig, "fix_migrations") -def test_sqlite_config_migrate_up_calls_commands() -> None: +def test_sqlite_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.migrate_up() delegates to SyncMigrationCommands.upgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - config.migrate_up(revision="head", allow_missing=True, auto_sync=False, dry_run=True) + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up(revision="head", allow_missing=True, 
auto_sync=False, dry_run=True) - mock_upgrade.assert_called_once_with("head", True, False, True) + mock_upgrade.assert_called_once_with("head", True, False, True) -def test_sqlite_config_migrate_down_calls_commands() -> None: +def test_sqlite_config_migrate_down_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.migrate_down() delegates to SyncMigrationCommands.downgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: - config.migrate_down(revision="-2", dry_run=True) + with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + config.migrate_down(revision="-2", dry_run=True) - mock_downgrade.assert_called_once_with("-2", dry_run=True) + mock_downgrade.assert_called_once_with("-2", dry_run=True) -def test_sqlite_config_get_current_migration_calls_commands() -> None: +def test_sqlite_config_get_current_migration_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.get_current_migration() delegates to SyncMigrationCommands.current().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "current", return_value="0001") as mock_current: - result = config.get_current_migration(verbose=True) + with patch.object(SyncMigrationCommands, "current", return_value="0001") as mock_current: + result = config.get_current_migration(verbose=True) - mock_current.assert_called_once_with(verbose=True) - assert result == "0001" + mock_current.assert_called_once_with(verbose=True) + assert result == "0001" -def test_sqlite_config_create_migration_calls_commands() -> None: +def test_sqlite_config_create_migration_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.create_migration() delegates to SyncMigrationCommands.revision().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: - config.create_migration(message="test migration", file_type="py") + with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: + config.create_migration(message="test migration", file_type="py") - mock_revision.assert_called_once_with("test migration", "py") + 
mock_revision.assert_called_once_with("test migration", "py") -def test_sqlite_config_init_migrations_calls_commands() -> None: +def test_sqlite_config_init_migrations_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.init_migrations() delegates to SyncMigrationCommands.init().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: - config.init_migrations(directory=str(migration_dir), package=False) + with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: + config.init_migrations(directory=str(migration_dir), package=False) - mock_init.assert_called_once_with(str(migration_dir), False) + mock_init.assert_called_once_with(str(migration_dir), False) -def test_sqlite_config_init_migrations_uses_default_directory() -> None: +def test_sqlite_config_init_migrations_uses_default_directory(tmp_path: Path) -> None: """Test that SqliteConfig.init_migrations() uses script_location when directory not provided.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: - config.init_migrations(package=True) + with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: + config.init_migrations(package=True) - mock_init.assert_called_once_with(str(migration_dir), True) + mock_init.assert_called_once_with(str(migration_dir), True) -def test_sqlite_config_stamp_migration_calls_commands() -> None: +def test_sqlite_config_stamp_migration_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.stamp_migration() delegates to SyncMigrationCommands.stamp().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "stamp", return_value=None) as mock_stamp: - config.stamp_migration(revision="0001") + with patch.object(SyncMigrationCommands, "stamp", return_value=None) as mock_stamp: + config.stamp_migration(revision="0001") - mock_stamp.assert_called_once_with("0001") + mock_stamp.assert_called_once_with("0001") -def test_sqlite_config_fix_migrations_calls_commands() -> None: +def test_sqlite_config_fix_migrations_calls_commands(tmp_path: Path) -> None: """Test that SqliteConfig.fix_migrations() delegates to 
SyncMigrationCommands.fix().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: - config.fix_migrations(dry_run=True, update_database=False, yes=True) + with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: + config.fix_migrations(dry_run=True, update_database=False, yes=True) - mock_fix.assert_called_once_with(True, False, True) + mock_fix.assert_called_once_with(True, False, True) @pytest.mark.asyncio -async def test_asyncpg_config_migrate_up_calls_commands() -> None: +async def test_asyncpg_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - await config.migrate_up(revision="0002", allow_missing=False, auto_sync=True, dry_run=False) + with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + await config.migrate_up(revision="0002", allow_missing=False, auto_sync=True, dry_run=False) - mock_upgrade.assert_called_once_with("0002", False, True, False) + mock_upgrade.assert_called_once_with("0002", False, True, False) @pytest.mark.asyncio -async def test_asyncpg_config_migrate_down_calls_commands() -> None: +async def test_asyncpg_config_migrate_down_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.migrate_down() delegates to AsyncMigrationCommands.downgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: - await config.migrate_down(revision="base", dry_run=False) + with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + await config.migrate_down(revision="base", dry_run=False) - mock_downgrade.assert_called_once_with("base", dry_run=False) + mock_downgrade.assert_called_once_with("base", dry_run=False) @pytest.mark.asyncio -async def test_asyncpg_config_get_current_migration_calls_commands() -> None: +async def test_asyncpg_config_get_current_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.get_current_migration() delegates to AsyncMigrationCommands.current().""" - with 
tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "current", return_value="0002") as mock_current: - result = await config.get_current_migration(verbose=False) + with patch.object(AsyncMigrationCommands, "current", return_value="0002") as mock_current: + result = await config.get_current_migration(verbose=False) - mock_current.assert_called_once_with(verbose=False) - assert result == "0002" + mock_current.assert_called_once_with(verbose=False) + assert result == "0002" @pytest.mark.asyncio -async def test_asyncpg_config_create_migration_calls_commands() -> None: +async def test_asyncpg_config_create_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.create_migration() delegates to AsyncMigrationCommands.revision().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: - await config.create_migration(message="add users table", file_type="sql") + with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: + await config.create_migration(message="add users table", file_type="sql") - mock_revision.assert_called_once_with("add users table", "sql") + mock_revision.assert_called_once_with("add users table", "sql") @pytest.mark.asyncio -async def test_asyncpg_config_init_migrations_calls_commands() -> None: +async def test_asyncpg_config_init_migrations_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.init_migrations() delegates to AsyncMigrationCommands.init().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: - await config.init_migrations(directory=str(migration_dir), package=True) + with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: + await config.init_migrations(directory=str(migration_dir), package=True) - mock_init.assert_called_once_with(str(migration_dir), True) + mock_init.assert_called_once_with(str(migration_dir), True) @pytest.mark.asyncio -async def test_asyncpg_config_stamp_migration_calls_commands() -> None: +async def test_asyncpg_config_stamp_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.stamp_migration() delegates to AsyncMigrationCommands.stamp().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / 
"migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "stamp", return_value=None) as mock_stamp: - await config.stamp_migration(revision="0003") + with patch.object(AsyncMigrationCommands, "stamp", return_value=None) as mock_stamp: + await config.stamp_migration(revision="0003") - mock_stamp.assert_called_once_with("0003") + mock_stamp.assert_called_once_with("0003") @pytest.mark.asyncio -async def test_asyncpg_config_fix_migrations_calls_commands() -> None: +async def test_asyncpg_config_fix_migrations_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.fix_migrations() delegates to AsyncMigrationCommands.fix().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: - await config.fix_migrations(dry_run=False, update_database=True, yes=False) + with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: + await config.fix_migrations(dry_run=False, update_database=True, yes=False) - mock_fix.assert_called_once_with(False, True, False) + mock_fix.assert_called_once_with(False, True, False) -def test_duckdb_pooled_config_migrate_up_calls_commands() -> None: +def test_duckdb_pooled_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that DuckDBConfig.migrate_up() delegates to SyncMigrationCommands.upgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = DuckDBConfig( - pool_config={"database": ":memory:"}, migration_config={"script_location": str(migration_dir)} - ) + config = DuckDBConfig( + pool_config={"database": ":memory:"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - config.migrate_up(revision="head", allow_missing=False, auto_sync=True, dry_run=False) + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up(revision="head", allow_missing=False, auto_sync=True, dry_run=False) - mock_upgrade.assert_called_once_with("head", False, True, False) + mock_upgrade.assert_called_once_with("head", False, True, False) -def test_duckdb_pooled_config_get_current_migration_calls_commands() -> None: +def test_duckdb_pooled_config_get_current_migration_calls_commands(tmp_path: Path) -> None: """Test that DuckDBConfig.get_current_migration() delegates to SyncMigrationCommands.current().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = DuckDBConfig( - pool_config={"database": ":memory:"}, migration_config={"script_location": str(migration_dir)} - ) + config = DuckDBConfig( + pool_config={"database": ":memory:"}, 
migration_config={"script_location": str(migration_dir)} + ) - with patch.object(SyncMigrationCommands, "current", return_value=None) as mock_current: - result = config.get_current_migration(verbose=False) + with patch.object(SyncMigrationCommands, "current", return_value=None) as mock_current: + result = config.get_current_migration(verbose=False) - mock_current.assert_called_once_with(verbose=False) - assert result is None + mock_current.assert_called_once_with(verbose=False) + assert result is None @pytest.mark.asyncio -async def test_aiosqlite_async_config_migrate_up_calls_commands() -> None: +async def test_aiosqlite_async_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that AiosqliteConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = AiosqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = AiosqliteConfig( + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - await config.migrate_up(revision="head", allow_missing=True, auto_sync=True, dry_run=True) + with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + await config.migrate_up(revision="head", allow_missing=True, auto_sync=True, dry_run=True) - mock_upgrade.assert_called_once_with("head", True, True, True) + mock_upgrade.assert_called_once_with("head", True, True, True) -def test_migrate_up_default_parameters_sync() -> None: +def test_migrate_up_default_parameters_sync(tmp_path: Path) -> None: """Test that migrate_up() uses correct default parameter values for sync configs.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" - temp_db = str(Path(temp_dir) / "test.db") + migration_dir = tmp_path / "migrations" + temp_db = str(tmp_path / "test.db") - config = SqliteConfig( - pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} - ) + config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}) - with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - config.migrate_up() + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up() - mock_upgrade.assert_called_once_with("head", False, True, False) + mock_upgrade.assert_called_once_with("head", False, True, False) @pytest.mark.asyncio -async def test_migrate_up_default_parameters_async() -> None: +async def test_migrate_up_default_parameters_async(tmp_path: Path) -> None: """Test that migrate_up() uses correct default parameter values for async configs.""" - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} - ) + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} + ) - with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: - 
+    with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade:
+        await config.migrate_up()
 
-            mock_upgrade.assert_called_once_with("head", False, True, False)
+        mock_upgrade.assert_called_once_with("head", False, True, False)
 
 
-def test_migrate_down_default_parameters_sync() -> None:
+def test_migrate_down_default_parameters_sync(tmp_path: Path) -> None:
     """Test that migrate_down() uses correct default parameter values for sync configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        temp_db = str(Path(temp_dir) / "test.db")
+    migration_dir = tmp_path / "migrations"
+    temp_db = str(tmp_path / "test.db")
 
-        config = SqliteConfig(
-            pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)})
 
-        with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade:
-            config.migrate_down()
+    with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade:
+        config.migrate_down()
 
-            mock_downgrade.assert_called_once_with("-1", dry_run=False)
+        mock_downgrade.assert_called_once_with("-1", dry_run=False)
 
 
 @pytest.mark.asyncio
-async def test_migrate_down_default_parameters_async() -> None:
+async def test_migrate_down_default_parameters_async(tmp_path: Path) -> None:
     """Test that migrate_down() uses correct default parameter values for async configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
+    migration_dir = tmp_path / "migrations"
 
-        config = AsyncpgConfig(
-            pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = AsyncpgConfig(
+        pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
+    )
 
-        with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade:
-            await config.migrate_down()
+    with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade:
+        await config.migrate_down()
 
-            mock_downgrade.assert_called_once_with("-1", dry_run=False)
+        mock_downgrade.assert_called_once_with("-1", dry_run=False)
 
 
-def test_create_migration_default_file_type_sync() -> None:
+def test_create_migration_default_file_type_sync(tmp_path: Path) -> None:
     """Test that create_migration() defaults to 'sql' file type for sync configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        temp_db = str(Path(temp_dir) / "test.db")
+    migration_dir = tmp_path / "migrations"
+    temp_db = str(tmp_path / "test.db")
 
-        config = SqliteConfig(
-            pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)})
 
-        with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision:
-            config.create_migration(message="test migration")
+    with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision:
+        config.create_migration(message="test migration")
 
-            mock_revision.assert_called_once_with("test migration", "sql")
+        mock_revision.assert_called_once_with("test migration", "sql")
 
 
 @pytest.mark.asyncio
-async def test_create_migration_default_file_type_async() -> None:
+async def test_create_migration_default_file_type_async(tmp_path: Path) -> None:
     """Test that create_migration() defaults to 'sql' file type for async configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
+    migration_dir = tmp_path / "migrations"
 
-    config = AsyncpgConfig(
-            pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = AsyncpgConfig(
+        pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
+    )
 
-        with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision:
-            await config.create_migration(message="test migration")
+    with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision:
+        await config.create_migration(message="test migration")
 
-            mock_revision.assert_called_once_with("test migration", "sql")
+        mock_revision.assert_called_once_with("test migration", "sql")
 
 
-def test_init_migrations_default_package_sync() -> None:
+def test_init_migrations_default_package_sync(tmp_path: Path) -> None:
     """Test that init_migrations() defaults to package=True for sync configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        temp_db = str(Path(temp_dir) / "test.db")
+    migration_dir = tmp_path / "migrations"
+    temp_db = str(tmp_path / "test.db")
 
-        config = SqliteConfig(
-            pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)})
 
-        with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init:
-            config.init_migrations(directory=str(migration_dir))
+    with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init:
+        config.init_migrations(directory=str(migration_dir))
 
-            mock_init.assert_called_once_with(str(migration_dir), True)
+        mock_init.assert_called_once_with(str(migration_dir), True)
 
 
 @pytest.mark.asyncio
-async def test_init_migrations_default_package_async() -> None:
+async def test_init_migrations_default_package_async(tmp_path: Path) -> None:
     """Test that init_migrations() defaults to package=True for async configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
+    migration_dir = tmp_path / "migrations"
 
-        config = AsyncpgConfig(
-            pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = AsyncpgConfig(
+        pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
    )
 
-        with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init:
-            await config.init_migrations(directory=str(migration_dir))
+    with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init:
+        await config.init_migrations(directory=str(migration_dir))
 
-            mock_init.assert_called_once_with(str(migration_dir), True)
+        mock_init.assert_called_once_with(str(migration_dir), True)
 
 
-def test_fix_migrations_default_parameters_sync() -> None:
+def test_fix_migrations_default_parameters_sync(tmp_path: Path) -> None:
     """Test that fix_migrations() uses correct default parameter values for sync configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
-        temp_db = str(Path(temp_dir) / "test.db")
+    migration_dir = tmp_path / "migrations"
+    temp_db = str(tmp_path / "test.db")
 
-        config = SqliteConfig(
-            pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = SqliteConfig(pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)})
 
-        with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix:
-            config.fix_migrations()
+    with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix:
+        config.fix_migrations()
 
-            mock_fix.assert_called_once_with(False, True, False)
+        mock_fix.assert_called_once_with(False, True, False)
 
 
 @pytest.mark.asyncio
-async def test_fix_migrations_default_parameters_async() -> None:
+async def test_fix_migrations_default_parameters_async(tmp_path: Path) -> None:
     """Test that fix_migrations() uses correct default parameter values for async configs."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migration_dir = Path(temp_dir) / "migrations"
+    migration_dir = tmp_path / "migrations"
 
-        config = AsyncpgConfig(
-            pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
-        )
+    config = AsyncpgConfig(
+        pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)}
    )
 
-        with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix:
-            await config.fix_migrations()
+    with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix:
+        await config.fix_migrations()
 
-            mock_fix.assert_called_once_with(False, True, False)
+        mock_fix.assert_called_once_with(False, True, False)
diff --git a/tests/unit/test_loader/test_cache_integration.py b/tests/unit/test_loader/test_cache_integration.py
index 7ff61b76..4e7ddb47 100644
--- a/tests/unit/test_loader/test_cache_integration.py
+++ b/tests/unit/test_loader/test_cache_integration.py
@@ -313,85 +313,83 @@ def test_cached_sqlfile_structure() -> None:
     assert set(cached_file.statement_names) == {"test_query_1", "test_query_2"}
 
 
-def test_namespace_handling_in_cache() -> None:
+def test_namespace_handling_in_cache(tmp_path: Path) -> None:
     """Test proper namespace handling in cached data."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "analytics").mkdir()
-        sql_file = base_path / "analytics" / "reports.sql"
-        sql_file.write_text("""
+    (base_path / "analytics").mkdir()
+    sql_file = base_path / "analytics" / "reports.sql"
+    sql_file.write_text("""
 -- name: user_report
 SELECT COUNT(*) FROM users;
 """)
 
-        loader = SQLFileLoader()
+    loader = SQLFileLoader()
 
-        with (
-            patch("sqlspec.loader.get_cache_config") as mock_config,
-            patch("sqlspec.loader.get_cache") as mock_cache_factory,
-        ):
-            mock_cache_config = Mock()
-            mock_cache_config.compiled_cache_enabled = True
-            mock_config.return_value = mock_cache_config
+    with (
+        patch("sqlspec.loader.get_cache_config") as mock_config,
+        patch("sqlspec.loader.get_cache") as mock_cache_factory,
+    ):
+        mock_cache_config = Mock()
+        mock_cache_config.compiled_cache_enabled = True
+        mock_config.return_value = mock_cache_config
 
-            mock_cache = Mock()
-            mock_cache.get.return_value = None
-            mock_cache_factory.return_value = mock_cache
+        mock_cache = Mock()
+        mock_cache.get.return_value = None
+        mock_cache_factory.return_value = mock_cache
 
-            loader.load_sql(base_path)
+        loader.load_sql(base_path)
 
-            assert "analytics.user_report" in loader._queries
+        assert "analytics.user_report" in loader._queries
 
-            mock_cache.put.assert_called()
-            cache_call_args = mock_cache.put.call_args[0]
-            assert cache_call_args[0] == "file"  # First arg should be "file" namespace
-            cached_data = cache_call_args[2]  # Third arg is the value in MultiLevelCache.put
+        mock_cache.put.assert_called()
+        cache_call_args = mock_cache.put.call_args[0]
+        assert cache_call_args[0] == "file"  # First arg should be "file" namespace
+        cached_data = cache_call_args[2]  # Third arg is the value in MultiLevelCache.put
 
-            assert isinstance(cached_data, CachedSQLFile)
+        assert isinstance(cached_data, CachedSQLFile)
 
-            assert "user_report" in cached_data.parsed_statements
-            assert "analytics.user_report" not in cached_data.parsed_statements
+        assert "user_report" in cached_data.parsed_statements
+        assert "analytics.user_report" not in cached_data.parsed_statements
 
 
-def test_cache_restoration_with_namespace() -> None:
+def test_cache_restoration_with_namespace(tmp_path: Path) -> None:
     """Test proper namespace restoration when loading from cache."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "reports").mkdir()
-        sql_file = base_path / "reports" / "daily.sql"
-        content = """
+    (base_path / "reports").mkdir()
+    sql_file = base_path / "reports" / "daily.sql"
+    content = """
 -- name: daily_users
 SELECT COUNT(*) FROM users WHERE date = CURRENT_DATE;
 """
-        sql_file.write_text(content)
+    sql_file.write_text(content)
 
-        cached_sql_file = SQLFile(content, str(sql_file))
-        cached_statements = {
-            "daily_users": NamedStatement("daily_users", "SELECT COUNT(*) FROM users WHERE date = CURRENT_DATE")
-        }
-        cached_file = CachedSQLFile(cached_sql_file, cached_statements)
+    cached_sql_file = SQLFile(content, str(sql_file))
+    cached_statements = {
+        "daily_users": NamedStatement("daily_users", "SELECT COUNT(*) FROM users WHERE date = CURRENT_DATE")
+    }
+    cached_file = CachedSQLFile(cached_sql_file, cached_statements)
 
-        loader = SQLFileLoader()
+    loader = SQLFileLoader()
 
-        with (
-            patch("sqlspec.loader.get_cache_config") as mock_config,
-            patch("sqlspec.loader.get_cache") as mock_cache_factory,
-            patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True),
-        ):
-            mock_cache_config = Mock()
-            mock_cache_config.compiled_cache_enabled = True
-            mock_config.return_value = mock_cache_config
+    with (
+        patch("sqlspec.loader.get_cache_config") as mock_config,
+        patch("sqlspec.loader.get_cache") as mock_cache_factory,
+        patch("sqlspec.loader.SQLFileLoader._is_file_unchanged", return_value=True),
+    ):
+        mock_cache_config = Mock()
+        mock_cache_config.compiled_cache_enabled = True
+        mock_config.return_value = mock_cache_config
 
-            mock_cache = Mock()
-            mock_cache.get.return_value = cached_file
-            mock_cache_factory.return_value = mock_cache
+        mock_cache = Mock()
+        mock_cache.get.return_value = cached_file
+        mock_cache_factory.return_value = mock_cache
 
-            loader._load_single_file(sql_file, "reports")
+        loader._load_single_file(sql_file, "reports")
 
-            assert "reports.daily_users" in loader._queries
-            assert "daily_users" not in loader._queries
+        assert "reports.daily_users" in loader._queries
+        assert "daily_users" not in loader._queries
 
 
 def test_cache_clear_integration() -> None:
diff --git a/tests/unit/test_loader/test_loading_patterns.py b/tests/unit/test_loader/test_loading_patterns.py
index 8896933c..222761ec 100644
--- a/tests/unit/test_loader/test_loading_patterns.py
+++ b/tests/unit/test_loader/test_loading_patterns.py
@@ -8,8 +8,6 @@
 - URI-based loading patterns
 """
 
-import tempfile
-from collections.abc import Generator
 from pathlib import Path
 from unittest.mock import Mock
 
@@ -23,17 +21,16 @@
 
 
 @pytest.fixture
-def temp_directory_structure() -> Generator[Path, None, None]:
+def temp_directory_structure(tmp_path: Path) -> Path:
     """Create a temporary directory structure for testing."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "queries").mkdir()
-        (base_path / "queries" / "users").mkdir()
-        (base_path / "queries" / "products").mkdir()
-        (base_path / "migrations").mkdir()
+    (base_path / "queries").mkdir()
+    (base_path / "queries" / "users").mkdir()
+    (base_path / "queries" / "products").mkdir()
+    (base_path / "migrations").mkdir()
 
-        (base_path / "root_queries.sql").write_text("""
+    (base_path / "root_queries.sql").write_text("""
 -- name: global_health_check
 SELECT 'OK' as status;
 
@@ -41,12 +38,12 @@ def temp_directory_structure() -> Generator[Path, None, None]:
 SELECT '1.0.0' as version;
 """)
 
-        (base_path / "queries" / "common.sql").write_text("""
+    (base_path / "queries" / "common.sql").write_text("""
 -- name: count_all_records
 SELECT COUNT(*) as total FROM information_schema.tables;
 """)
 
-        (base_path / "queries" / "users" / "user_queries.sql").write_text("""
+    (base_path / "queries" / "users" / "user_queries.sql").write_text("""
 -- name: get_user_by_id
 SELECT id, name, email FROM users WHERE id = :user_id;
 
@@ -54,7 +51,7 @@ def temp_directory_structure() -> Generator[Path, None, None]:
 SELECT id, name FROM users WHERE active = true;
 """)
 
-        (base_path / "queries" / "products" / "product_queries.sql").write_text("""
+    (base_path / "queries" / "products" / "product_queries.sql").write_text("""
 -- name: get_product_by_id
 SELECT id, name, price FROM products WHERE id = :product_id;
 
@@ -62,11 +59,11 @@ def temp_directory_structure() -> Generator[Path, None, None]:
 SELECT * FROM products WHERE category_id = :category_id;
 """)
 
-        (base_path / "README.md").write_text("# Test Documentation")
-        (base_path / "config.json").write_text('{"setting": "value"}')
-        (base_path / "queries" / ".gitkeep").write_text("")
+    (base_path / "README.md").write_text("# Test Documentation")
+    (base_path / "config.json").write_text('{"setting": "value"}')
+    (base_path / "queries" / ".gitkeep").write_text("")
 
-        yield base_path
+    return base_path
 
 
 def test_load_single_file(temp_directory_structure: Path) -> None:
@@ -129,34 +126,32 @@ def test_load_parent_directory_with_namespaces(temp_directory_structure: Path)
     assert "products.list_products_by_category" in queries
 
 
-def test_empty_directory_handling() -> None:
+def test_empty_directory_handling(tmp_path: Path) -> None:
     """Test handling of empty directories."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        empty_dir = Path(temp_dir) / "empty"
-        empty_dir.mkdir()
+    empty_dir = tmp_path / "empty"
+    empty_dir.mkdir()
 
-        loader = SQLFileLoader()
+    loader = SQLFileLoader()
 
-        loader.load_sql(empty_dir)
+    loader.load_sql(empty_dir)
 
-        assert loader.list_queries() == []
-        assert loader.list_files() == []
+    assert loader.list_queries() == []
+    assert loader.list_files() == []
 
 
-def test_directory_with_only_non_sql_files() -> None:
+def test_directory_with_only_non_sql_files(tmp_path: Path) -> None:
     """Test directory containing only non-SQL files."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "README.md").write_text("# Documentation")
-        (base_path / "config.json").write_text('{"key": "value"}')
-        (base_path / "script.py").write_text("print('hello')")
+    (base_path / "README.md").write_text("# Documentation")
+    (base_path / "config.json").write_text('{"key": "value"}')
+    (base_path / "script.py").write_text("print('hello')")
 
-        loader = SQLFileLoader()
-        loader.load_sql(base_path)
+    loader = SQLFileLoader()
+    loader.load_sql(base_path)
 
-        assert loader.list_queries() == []
-        assert loader.list_files() == []
+    assert loader.list_queries() == []
+    assert loader.list_files() == []
 
 
 def test_mixed_file_and_directory_loading(temp_directory_structure: Path) -> None:
@@ -176,205 +171,194 @@ def test_mixed_file_and_directory_loading(temp_directory_structure: Path) -> Non
     assert "list_active_users" in queries
 
 
-def test_simple_namespace_generation() -> None:
+def test_simple_namespace_generation(tmp_path: Path) -> None:
     """Test simple directory-to-namespace conversion."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "analytics").mkdir()
-        (base_path / "analytics" / "reports.sql").write_text("""
+    (base_path / "analytics").mkdir()
+    (base_path / "analytics" / "reports.sql").write_text("""
 -- name: user_report
 SELECT COUNT(*) FROM users;
 """)
 
-        loader = SQLFileLoader()
-        loader.load_sql(base_path)
+    loader = SQLFileLoader()
+    loader.load_sql(base_path)
 
-        queries = loader.list_queries()
-        assert "analytics.user_report" in queries
+    queries = loader.list_queries()
+    assert "analytics.user_report" in queries
 
 
-def test_deep_namespace_generation() -> None:
+def test_deep_namespace_generation(tmp_path: Path) -> None:
     """Test deep directory structure namespace generation."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        deep_path = base_path / "level1" / "level2" / "level3"
-        deep_path.mkdir(parents=True)
+    deep_path = base_path / "level1" / "level2" / "level3"
+    deep_path.mkdir(parents=True)
 
-        (deep_path / "deep_queries.sql").write_text("""
+    (deep_path / "deep_queries.sql").write_text("""
 -- name: deeply_nested_query
 SELECT 'deep' as level;
 """)
 
-        loader = SQLFileLoader()
-        loader.load_sql(base_path)
+    loader = SQLFileLoader()
+    loader.load_sql(base_path)
 
-        queries = loader.list_queries()
-        assert "level1.level2.level3.deeply_nested_query" in queries
+    queries = loader.list_queries()
+    assert "level1.level2.level3.deeply_nested_query" in queries
 
 
-def test_namespace_with_special_characters() -> None:
+def test_namespace_with_special_characters(tmp_path: Path) -> None:
     """Test namespace generation with special directory names."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "user-analytics").mkdir()
-        (base_path / "user-analytics" / "daily_reports.sql").write_text("""
+    (base_path / "user-analytics").mkdir()
+    (base_path / "user-analytics" / "daily_reports.sql").write_text("""
 -- name: daily_user_count
 SELECT COUNT(*) FROM users;
 """)
 
-        loader = SQLFileLoader()
-        loader.load_sql(base_path)
+    loader = SQLFileLoader()
+    loader.load_sql(base_path)
 
-        queries = loader.list_queries()
+    queries = loader.list_queries()
 
-        assert "user-analytics.daily_user_count" in queries
+    assert "user-analytics.daily_user_count" in queries
 
 
-def test_no_namespace_for_root_files() -> None:
+def test_no_namespace_for_root_files(tmp_path: Path) -> None:
     """Test that root-level files don't get namespaces."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        base_path = Path(temp_dir)
+    base_path = tmp_path
 
-        (base_path / "root_query.sql").write_text("""
"root_query.sql").write_text(""" + (base_path / "root_query.sql").write_text(""" -- name: root_level_query SELECT 'root' as level; """) - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() + queries = loader.list_queries() - assert "root_level_query" in queries - assert "root_level_query" not in [q for q in queries if "." in q] + assert "root_level_query" in queries + assert "root_level_query" not in [q for q in queries if "." in q] -def test_sql_extension_filtering() -> None: +def test_sql_extension_filtering(tmp_path: Path) -> None: """Test that only .sql files are processed.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - (base_path / "valid.sql").write_text(""" + (base_path / "valid.sql").write_text(""" -- name: valid_query SELECT 1; """) - (base_path / "invalid.txt").write_text(""" + (base_path / "invalid.txt").write_text(""" -- name: invalid_query SELECT 2; """) - (base_path / "also_invalid.py").write_text("# Not a SQL file") + (base_path / "also_invalid.py").write_text("# Not a SQL file") - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() - assert "valid_query" in queries - assert len(queries) == 1 + queries = loader.list_queries() + assert "valid_query" in queries + assert len(queries) == 1 -def test_hidden_file_inclusion() -> None: +def test_hidden_file_inclusion(tmp_path: Path) -> None: """Test that hidden files (starting with .) are currently included.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - (base_path / "visible.sql").write_text(""" + (base_path / "visible.sql").write_text(""" -- name: visible_query SELECT 1; """) - (base_path / ".hidden.sql").write_text(""" + (base_path / ".hidden.sql").write_text(""" -- name: hidden_query SELECT 2; """) - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() - assert "visible_query" in queries + queries = loader.list_queries() + assert "visible_query" in queries - assert "hidden_query" in queries - assert len(queries) == 2 + assert "hidden_query" in queries + assert len(queries) == 2 -def test_recursive_pattern_matching() -> None: +def test_recursive_pattern_matching(tmp_path: Path) -> None: """Test recursive pattern matching across directory levels.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - (base_path / "level1").mkdir() - (base_path / "level1" / "level2").mkdir() + (base_path / "level1").mkdir() + (base_path / "level1" / "level2").mkdir() - (base_path / "level1" / "query1.sql").write_text(""" + (base_path / "level1" / "query1.sql").write_text(""" -- name: query_level1 SELECT 1; """) - (base_path / "level1" / "level2" / "query2.sql").write_text(""" + (base_path / "level1" / "level2" / "query2.sql").write_text(""" -- name: query_level2 SELECT 2; """) - (base_path / "level1" / "level2" / "not_sql.txt").write_text("Not SQL") + (base_path / "level1" / "level2" / "not_sql.txt").write_text("Not SQL") - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() - assert "level1.query_level1" in queries - assert "level1.level2.query_level2" in queries - assert len(queries) == 2 + queries = 
loader.list_queries() + assert "level1.query_level1" in queries + assert "level1.level2.query_level2" in queries + assert len(queries) == 2 -def test_file_uri_loading() -> None: +def test_file_uri_loading(tmp_path: Path) -> None: """Test loading SQL files using file:// URIs.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tf: - tf.write(""" + sql_file = tmp_path / "uri_test.sql" + sql_file.write_text(""" -- name: uri_query SELECT 'loaded from URI' as source; """) - tf.flush() - - loader = SQLFileLoader() - file_uri = f"file://{tf.name}" - loader.load_sql(file_uri) + loader = SQLFileLoader() + file_uri = f"file://{sql_file}" - queries = loader.list_queries() - assert "uri_query" in queries + loader.load_sql(file_uri) - sql = loader.get_sql("uri_query") - assert "loaded from URI" in sql.sql + queries = loader.list_queries() + assert "uri_query" in queries - Path(tf.name).unlink() + sql = loader.get_sql("uri_query") + assert "loaded from URI" in sql.sql -def test_mixed_local_and_uri_loading() -> None: +def test_mixed_local_and_uri_loading(tmp_path: Path) -> None: """Test loading both local files and URIs together.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - local_file = base_path / "local.sql" - local_file.write_text(""" + local_file = base_path / "local.sql" + local_file.write_text(""" -- name: local_query SELECT 'local' as source; """) - uri_file = base_path / "uri_file.sql" - uri_file.write_text(""" + uri_file = base_path / "uri_file.sql" + uri_file.write_text(""" -- name: uri_query SELECT 'uri' as source; """) - loader = SQLFileLoader() + loader = SQLFileLoader() - file_uri = f"file://{uri_file}" - loader.load_sql(local_file, file_uri) + file_uri = f"file://{uri_file}" + loader.load_sql(local_file, file_uri) - queries = loader.list_queries() - assert "local_query" in queries - assert "uri_query" in queries - assert len(queries) == 2 + queries = loader.list_queries() + assert "local_query" in queries + assert "uri_query" in queries + assert len(queries) == 2 def test_invalid_uri_handling() -> None: @@ -399,227 +383,210 @@ def test_nonexistent_directory_error() -> None: assert loader.list_files() == [] -def test_sql_file_without_named_statements_skipped() -> None: +def test_sql_file_without_named_statements_skipped(tmp_path: Path) -> None: """Test that SQL files without named statements are gracefully skipped.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tf: - tf.write("SELECT * FROM users; -- No name comment") - tf.flush() + sql_file = tmp_path / "no_names.sql" + sql_file.write_text("SELECT * FROM users; -- No name comment") - loader = SQLFileLoader() - loader.load_sql(tf.name) - - assert len(loader.list_queries()) == 0 - assert str(tf.name) not in loader._files # pyright: ignore + loader = SQLFileLoader() + loader.load_sql(str(sql_file)) - Path(tf.name).unlink() + assert len(loader.list_queries()) == 0 + assert str(sql_file) not in loader._files # pyright: ignore -def test_duplicate_queries_across_files_error() -> None: +def test_duplicate_queries_across_files_error(tmp_path: Path) -> None: """Test error handling for duplicate query names across files.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - file1 = base_path / "file1.sql" - file1.write_text(""" + file1 = base_path / "file1.sql" + file1.write_text(""" -- name: duplicate_query SELECT 'from file1' as source; """) - file2 = base_path / "file2.sql" 
- file2.write_text(""" + file2 = base_path / "file2.sql" + file2.write_text(""" -- name: duplicate_query SELECT 'from file2' as source; """) - loader = SQLFileLoader() + loader = SQLFileLoader() - loader.load_sql(file1) + loader.load_sql(file1) - with pytest.raises(SQLFileParseError) as exc_info: - loader.load_sql(file2) + with pytest.raises(SQLFileParseError) as exc_info: + loader.load_sql(file2) - assert "already exists" in str(exc_info.value) + assert "already exists" in str(exc_info.value) -def test_encoding_error_handling() -> None: +def test_encoding_error_handling(tmp_path: Path) -> None: """Test handling of encoding errors.""" - with tempfile.NamedTemporaryFile(mode="wb", suffix=".sql", delete=False) as tf: - tf.write(b"\xff\xfe-- name: test\nSELECT 1;") - tf.flush() - - loader = SQLFileLoader(encoding="utf-8") + sql_file = tmp_path / "bad_encoding.sql" + sql_file.write_bytes(b"\xff\xfe-- name: test\nSELECT 1;") - with pytest.raises(SQLFileParseError): - loader.load_sql(tf.name) + loader = SQLFileLoader(encoding="utf-8") - Path(tf.name).unlink() + with pytest.raises(SQLFileParseError): + loader.load_sql(str(sql_file)) -def test_large_file_handling() -> None: +def test_large_file_handling(tmp_path: Path) -> None: """Test handling of large SQL files.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tf: - content = [ - f""" + sql_file = tmp_path / "large.sql" + content = [ + f""" -- name: query_{i:03d} SELECT {i} as query_number, 'data_{i}' as data FROM large_table WHERE id > {i * 100} LIMIT 1000; """ - for i in range(100) - ] - - tf.write("\n".join(content)) - tf.flush() + for i in range(100) + ] - loader = SQLFileLoader() + sql_file.write_text("\n".join(content)) - loader.load_sql(tf.name) + loader = SQLFileLoader() - queries = loader.list_queries() - assert len(queries) == 100 + loader.load_sql(str(sql_file)) - assert "query_000" in queries - assert "query_050" in queries - assert "query_099" in queries + queries = loader.list_queries() + assert len(queries) == 100 - Path(tf.name).unlink() + assert "query_000" in queries + assert "query_050" in queries + assert "query_099" in queries -def test_deep_directory_structure_performance() -> None: +def test_deep_directory_structure_performance(tmp_path: Path) -> None: """Test performance with deep directory structures.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - current_path = base_path - for i in range(10): - current_path = current_path / f"level_{i}" - current_path.mkdir() + current_path = base_path + for i in range(10): + current_path = current_path / f"level_{i}" + current_path.mkdir() - sql_file = current_path / f"queries_level_{i}.sql" - sql_file.write_text( - f""" + sql_file = current_path / f"queries_level_{i}.sql" + sql_file.write_text( + f""" -- name: query_at_level_{i} SELECT {i} as level_number; """ - ) + ) - loader = SQLFileLoader() + loader = SQLFileLoader() - loader.load_sql(base_path) + loader.load_sql(base_path) - queries = loader.list_queries() - assert len(queries) == 10 + queries = loader.list_queries() + assert len(queries) == 10 - deepest_query = ( - "level_0.level_1.level_2.level_3.level_4.level_5.level_6.level_7.level_8.level_9.query_at_level_9" - ) - assert deepest_query in queries + deepest_query = "level_0.level_1.level_2.level_3.level_4.level_5.level_6.level_7.level_8.level_9.query_at_level_9" + assert deepest_query in queries -def test_concurrent_loading_safety() -> None: +def test_concurrent_loading_safety(tmp_path: 
Path) -> None: """Test thread safety during concurrent loading operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - for i in range(5): - sql_file = base_path / f"concurrent_{i}.sql" - sql_file.write_text( - f""" + for i in range(5): + sql_file = base_path / f"concurrent_{i}.sql" + sql_file.write_text( + f""" -- name: concurrent_query_{i} SELECT {i} as concurrent_id; """ - ) + ) - loader = SQLFileLoader() + loader = SQLFileLoader() - for i in range(5): - sql_file = base_path / f"concurrent_{i}.sql" - loader.load_sql(sql_file) + for i in range(5): + sql_file = base_path / f"concurrent_{i}.sql" + loader.load_sql(sql_file) - queries = loader.list_queries() - assert len(queries) == 5 + queries = loader.list_queries() + assert len(queries) == 5 - for i in range(5): - assert f"concurrent_query_{i}" in queries + for i in range(5): + assert f"concurrent_query_{i}" in queries -def test_symlink_handling() -> None: +def test_symlink_handling(tmp_path: Path) -> None: """Test handling of symbolic links.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - original_file = base_path / "original.sql" - original_file.write_text( - """ + original_file = base_path / "original.sql" + original_file.write_text( + """ -- name: symlinked_query SELECT 'original' as source; """ - ) + ) - symlink_file = base_path / "symlinked.sql" - try: - symlink_file.symlink_to(original_file) - except OSError: - pytest.skip("Symbolic links not supported on this system") + symlink_file = base_path / "symlinked.sql" + try: + symlink_file.symlink_to(original_file) + except OSError: + pytest.skip("Symbolic links not supported on this system") - loader = SQLFileLoader() - loader.load_sql(symlink_file) + loader = SQLFileLoader() + loader.load_sql(symlink_file) - queries = loader.list_queries() - assert "symlinked_query" in queries + queries = loader.list_queries() + assert "symlinked_query" in queries -def test_case_sensitivity_handling() -> None: +def test_case_sensitivity_handling(tmp_path: Path) -> None: """Test handling of case-sensitive file systems.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - (base_path / "Queries.SQL").write_text( - """ + (base_path / "Queries.SQL").write_text( + """ -- name: uppercase_extension_query SELECT 'UPPERCASE' as extension_type; """ - ) + ) - (base_path / "queries.sql").write_text( - """ + (base_path / "queries.sql").write_text( + """ -- name: lowercase_extension_query SELECT 'lowercase' as extension_type; """ - ) + ) - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() + queries = loader.list_queries() - assert len(queries) >= 1 - assert "lowercase_extension_query" in queries or "uppercase_extension_query" in queries + assert len(queries) >= 1 + assert "lowercase_extension_query" in queries or "uppercase_extension_query" in queries -def test_unicode_filename_handling() -> None: +def test_unicode_filename_handling(tmp_path: Path) -> None: """Test handling of Unicode filenames.""" - with tempfile.TemporaryDirectory() as temp_dir: - base_path = Path(temp_dir) + base_path = tmp_path - unicode_file = base_path / "测试_файл_query.sql" - try: - unicode_file.write_text( - """ + unicode_file = base_path / "测试_файл_query.sql" + try: + unicode_file.write_text( + """ -- name: unicode_filename_query SELECT 'Unicode filename support' as message; 
""", - encoding="utf-8", - ) - except OSError: - pytest.skip("Unicode filenames not supported on this system") + encoding="utf-8", + ) + except OSError: + pytest.skip("Unicode filenames not supported on this system") - loader = SQLFileLoader() - loader.load_sql(base_path) + loader = SQLFileLoader() + loader.load_sql(base_path) - queries = loader.list_queries() - assert "unicode_filename_query" in queries + queries = loader.list_queries() + assert "unicode_filename_query" in queries @pytest.fixture diff --git a/tests/unit/test_loader/test_sql_file_loader.py b/tests/unit/test_loader/test_sql_file_loader.py index 7de136c5..48ab1f46 100644 --- a/tests/unit/test_loader/test_sql_file_loader.py +++ b/tests/unit/test_loader/test_sql_file_loader.py @@ -9,7 +9,6 @@ - Parameter style detection and preservation """ -import tempfile import time from pathlib import Path from unittest.mock import Mock, patch @@ -342,40 +341,34 @@ def test_generate_file_cache_key() -> None: assert len(key1.split(":")[1]) == 16 -def test_calculate_file_checksum() -> None: +def test_calculate_file_checksum(tmp_path: Path) -> None: """Test file checksum calculation.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tf: - tf.write("SELECT * FROM users;") - tf.flush() + sql_file = tmp_path / "test.sql" + sql_file.write_text("SELECT * FROM users;") - loader = SQLFileLoader() - checksum = loader._calculate_file_checksum(tf.name) - - assert isinstance(checksum, str) - assert len(checksum) == 32 + loader = SQLFileLoader() + checksum = loader._calculate_file_checksum(str(sql_file)) - Path(tf.name).unlink() + assert isinstance(checksum, str) + assert len(checksum) == 32 -def test_is_file_unchanged() -> None: +def test_is_file_unchanged(tmp_path: Path) -> None: """Test file change detection.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as tf: - original_content = "SELECT * FROM users;" - tf.write(original_content) - tf.flush() + sql_file = tmp_path / "test.sql" + original_content = "SELECT * FROM users;" + sql_file.write_text(original_content) - loader = SQLFileLoader() - - sql_file = SQLFile(original_content, tf.name) - cached_file = CachedSQLFile(sql_file, {}) + loader = SQLFileLoader() - assert loader._is_file_unchanged(tf.name, cached_file) + sql_file_obj = SQLFile(original_content, str(sql_file)) + cached_file = CachedSQLFile(sql_file_obj, {}) - Path(tf.name).write_text("SELECT * FROM products;") + assert loader._is_file_unchanged(str(sql_file), cached_file) - assert not loader._is_file_unchanged(tf.name, cached_file) + sql_file.write_text("SELECT * FROM products;") - Path(tf.name).unlink() + assert not loader._is_file_unchanged(str(sql_file), cached_file) def test_add_named_sql() -> None: diff --git a/tests/unit/test_migrations/test_extension_discovery.py b/tests/unit/test_migrations/test_extension_discovery.py index 596ce89c..c9afa8d8 100644 --- a/tests/unit/test_migrations/test_extension_discovery.py +++ b/tests/unit/test_migrations/test_extension_discovery.py @@ -1,84 +1,62 @@ """Test extension migration discovery functionality.""" -import tempfile from pathlib import Path from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.migrations.commands import SyncMigrationCommands -def test_extension_migration_discovery() -> None: +def test_extension_migration_discovery(tmp_path: Path) -> None: """Test that extension migrations are discovered when configured.""" - with tempfile.TemporaryDirectory() as temp_dir: - # Create config with extension migrations 
enabled
-        config = SqliteConfig(
-            pool_config={"database": ":memory:"},
-            migration_config={
-                "script_location": str(temp_dir),
-                "version_table_name": "test_migrations",
-                "include_extensions": ["litestar"],
-            },
-        )
-
-        # Create migration commands
-        commands = SyncMigrationCommands(config)
-
-        # Check that extension migrations were discovered
-        assert hasattr(commands, "runner")
-        assert hasattr(commands.runner, "extension_migrations")
-
-        # Should have discovered Litestar migrations directory if it exists
-        if "litestar" in commands.runner.extension_migrations:
-            litestar_path = commands.runner.extension_migrations["litestar"]
-            assert litestar_path.exists()
-            assert litestar_path.name == "migrations"
-
-
-def test_extension_migration_context() -> None:
+    config = SqliteConfig(
+        pool_config={"database": ":memory:"},
+        migration_config={
+            "script_location": str(tmp_path),
+            "version_table_name": "test_migrations",
+            "include_extensions": ["litestar"],
+        },
+    )
+
+    commands = SyncMigrationCommands(config)
+
+    assert hasattr(commands, "runner")
+    assert hasattr(commands.runner, "extension_migrations")
+
+    if "litestar" in commands.runner.extension_migrations:
+        litestar_path = commands.runner.extension_migrations["litestar"]
+        assert litestar_path.exists()
+        assert litestar_path.name == "migrations"
+
+
+def test_extension_migration_context(tmp_path: Path) -> None:
     """Test that migration context is created with dialect information."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create config with known dialect
-        config = SqliteConfig(
-            pool_config={"database": ":memory:"},
-            migration_config={"script_location": str(temp_dir), "include_extensions": ["litestar"]},
-        )
+    config = SqliteConfig(
+        pool_config={"database": ":memory:"},
+        migration_config={"script_location": str(tmp_path), "include_extensions": ["litestar"]},
+    )
 
-        # Create migration commands - this should create context
-        commands = SyncMigrationCommands(config)
+    commands = SyncMigrationCommands(config)
 
-        # The runner should have a context with dialect
-        assert hasattr(commands.runner, "context")
-        assert commands.runner.context is not None
-        assert commands.runner.context.dialect == "sqlite"
+    assert hasattr(commands.runner, "context")
+    assert commands.runner.context is not None
+    assert commands.runner.context.dialect == "sqlite"
 
 
-def test_no_extensions_by_default() -> None:
+def test_no_extensions_by_default(tmp_path: Path) -> None:
     """Test that no extension migrations are included by default."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Create config without extension migrations
-        config = SqliteConfig(
-            pool_config={"database": ":memory:"},
-            migration_config={
-                "script_location": str(temp_dir)
-                # No include_extensions key
-            },
-        )
+    config = SqliteConfig(pool_config={"database": ":memory:"}, migration_config={"script_location": str(tmp_path)})
 
-        # Create migration commands
-        commands = SyncMigrationCommands(config)
+    commands = SyncMigrationCommands(config)
 
-        # Should have no extension migrations
-        assert commands.runner.extension_migrations == {}
+    assert commands.runner.extension_migrations == {}
 
 
-def test_migration_file_discovery_with_extensions() -> None:
+def test_migration_file_discovery_with_extensions(tmp_path: Path) -> None:
     """Test that migration files are discovered from both primary and extension paths."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migrations_dir = Path(temp_dir)
+    migrations_dir = tmp_path
 
-        # Create a primary migration
-        primary_migration = 
migrations_dir / "0002_user_table.sql" - primary_migration.write_text(""" + primary_migration = migrations_dir / "0002_user_table.sql" + primary_migration.write_text(""" -- name: migrate-0002-up CREATE TABLE users (id INTEGER); @@ -86,23 +64,15 @@ def test_migration_file_discovery_with_extensions() -> None: DROP TABLE users; """) - # Create config with extension migrations - config = SqliteConfig( - pool_config={"database": ":memory:"}, - migration_config={"script_location": str(migrations_dir), "include_extensions": ["litestar"]}, - ) - - # Create migration commands - commands = SyncMigrationCommands(config) + config = SqliteConfig( + pool_config={"database": ":memory:"}, + migration_config={"script_location": str(migrations_dir), "include_extensions": ["litestar"]}, + ) - # Get all migration files - migration_files = commands.runner.get_migration_files() + commands = SyncMigrationCommands(config) - # Should have both primary and extension migrations - versions = [version for version, _ in migration_files] + migration_files = commands.runner.get_migration_files() - # Primary migration - assert "0002" in versions + versions = [version for version, _ in migration_files] - # Extension migrations should be prefixed (if any exist) - # Note: Extension migrations only exist when specific extension features are available + assert "0002" in versions diff --git a/tests/unit/test_migrations/test_migration.py b/tests/unit/test_migrations/test_migration.py index 66198f4b..0732456f 100644 --- a/tests/unit/test_migrations/test_migration.py +++ b/tests/unit/test_migrations/test_migration.py @@ -10,9 +10,6 @@ - Error handling and validation """ -from __future__ import annotations - -import tempfile from pathlib import Path from typing import Any from unittest.mock import Mock, patch @@ -124,14 +121,13 @@ def test_calculate_checksum_unicode_content() -> None: assert len(checksum) == 32 -def test_get_migration_files_sync_empty_directory() -> None: +def test_get_migration_files_sync_empty_directory(tmp_path: Path) -> None: """Test getting migration files from empty directory.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - runner = MockMigrationRunner(migrations_path) + migrations_path = tmp_path + runner = MockMigrationRunner(migrations_path) - files = runner._get_migration_files_sync() - assert files == [] + files = runner._get_migration_files_sync() + assert files == [] def test_get_migration_files_sync_nonexistent_directory() -> None: @@ -143,88 +139,83 @@ def test_get_migration_files_sync_nonexistent_directory() -> None: assert files == [] -def test_get_migration_files_sync_with_sql_files() -> None: +def test_get_migration_files_sync_with_sql_files(tmp_path: Path) -> None: """Test getting migration files with SQL files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - (migrations_path / "0001_initial.sql").write_text("-- Initial migration") - (migrations_path / "0003_add_indexes.sql").write_text("-- Add indexes") - (migrations_path / "0002_add_users.sql").write_text("-- Add users table") + (migrations_path / "0001_initial.sql").write_text("-- Initial migration") + (migrations_path / "0003_add_indexes.sql").write_text("-- Add indexes") + (migrations_path / "0002_add_users.sql").write_text("-- Add users table") - (migrations_path / "README.md").write_text("# Migrations") - (migrations_path / "config.json").write_text("{}") + (migrations_path / "README.md").write_text("# Migrations") + 
(migrations_path / "config.json").write_text("{}") - runner = MockMigrationRunner(migrations_path) - files = runner._get_migration_files_sync() + runner = MockMigrationRunner(migrations_path) + files = runner._get_migration_files_sync() - assert len(files) == 3 - assert files[0][0] == "0001" - assert files[1][0] == "0002" - assert files[2][0] == "0003" + assert len(files) == 3 + assert files[0][0] == "0001" + assert files[1][0] == "0002" + assert files[2][0] == "0003" - assert files[0][1].name == "0001_initial.sql" - assert files[1][1].name == "0002_add_users.sql" - assert files[2][1].name == "0003_add_indexes.sql" + assert files[0][1].name == "0001_initial.sql" + assert files[1][1].name == "0002_add_users.sql" + assert files[2][1].name == "0003_add_indexes.sql" -def test_get_migration_files_sync_with_python_files() -> None: +def test_get_migration_files_sync_with_python_files(tmp_path: Path) -> None: """Test getting migration files with Python files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - (migrations_path / "0001_initial.py").write_text("# Initial migration") - (migrations_path / "0002_data_migration.py").write_text("# Data migration") + (migrations_path / "0001_initial.py").write_text("# Initial migration") + (migrations_path / "0002_data_migration.py").write_text("# Data migration") - runner = MockMigrationRunner(migrations_path) - files = runner._get_migration_files_sync() + runner = MockMigrationRunner(migrations_path) + files = runner._get_migration_files_sync() - assert len(files) == 2 - assert files[0][0] == "0001" - assert files[1][0] == "0002" + assert len(files) == 2 + assert files[0][0] == "0001" + assert files[1][0] == "0002" -def test_get_migration_files_sync_mixed_types() -> None: +def test_get_migration_files_sync_mixed_types(tmp_path: Path) -> None: """Test getting migration files with mixed SQL and Python files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - (migrations_path / "0001_initial.sql").write_text("-- SQL migration") - (migrations_path / "0002_data_migration.py").write_text("# Python migration") - (migrations_path / "0003_add_indexes.sql").write_text("-- Another SQL migration") + (migrations_path / "0001_initial.sql").write_text("-- SQL migration") + (migrations_path / "0002_data_migration.py").write_text("# Python migration") + (migrations_path / "0003_add_indexes.sql").write_text("-- Another SQL migration") - runner = MockMigrationRunner(migrations_path) - files = runner._get_migration_files_sync() + runner = MockMigrationRunner(migrations_path) + files = runner._get_migration_files_sync() - assert len(files) == 3 - assert files[0][0] == "0001" - assert files[1][0] == "0002" - assert files[2][0] == "0003" + assert len(files) == 3 + assert files[0][0] == "0001" + assert files[1][0] == "0002" + assert files[2][0] == "0003" -def test_get_migration_files_sync_hidden_files_ignored() -> None: +def test_get_migration_files_sync_hidden_files_ignored(tmp_path: Path) -> None: """Test that hidden files are ignored.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - (migrations_path / "0001_visible.sql").write_text("-- Visible migration") - (migrations_path / ".0002_hidden.sql").write_text("-- Hidden migration") - (migrations_path / ".gitkeep").write_text("") + (migrations_path / "0001_visible.sql").write_text("-- Visible migration") + (migrations_path / 
".0002_hidden.sql").write_text("-- Hidden migration") + (migrations_path / ".gitkeep").write_text("") - runner = MockMigrationRunner(migrations_path) - files = runner._get_migration_files_sync() + runner = MockMigrationRunner(migrations_path) + files = runner._get_migration_files_sync() - assert len(files) == 1 - assert files[0][1].name == "0001_visible.sql" + assert len(files) == 1 + assert files[0][1].name == "0001_visible.sql" -def test_load_migration_metadata_sql_file() -> None: +def test_load_migration_metadata_sql_file(tmp_path: Path) -> None: """Test loading metadata from SQL migration file.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - migration_file = migrations_path / "0001_create_users.sql" - migration_content = """ + migration_file = migrations_path / "0001_create_users.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE users ( id INTEGER PRIMARY KEY, @@ -235,37 +226,36 @@ def test_load_migration_metadata_sql_file() -> None: -- name: migrate-0001-down DROP TABLE users; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - runner.loader.clear_cache = Mock() - runner.loader.load_sql = Mock() - runner.loader.has_query = Mock(side_effect=lambda query: True) + runner.loader.clear_cache = Mock() + runner.loader.load_sql = Mock() + runner.loader.has_query = Mock(side_effect=lambda query: True) - with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_get_loader.return_value = mock_loader + with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_get_loader.return_value = mock_loader - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "create_users" - assert metadata["file_path"] == migration_file - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True - assert isinstance(metadata["checksum"], str) - assert len(metadata["checksum"]) == 32 + assert metadata["version"] == "0001" + assert metadata["description"] == "create_users" + assert metadata["file_path"] == migration_file + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True + assert isinstance(metadata["checksum"], str) + assert len(metadata["checksum"]) == 32 -def test_load_migration_metadata_python_file_sync() -> None: +def test_load_migration_metadata_python_file_sync(tmp_path: Path) -> None: """Test loading metadata from Python migration file with sync functions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - migration_file = migrations_path / "0001_data_migration.py" - migration_content = ''' + migration_file = migrations_path / "0001_data_migration.py" + migration_content = ''' def up(): """Upgrade migration.""" return ["INSERT INTO users (name, email) VALUES ('admin', 'admin@example.com');"] @@ -274,38 +264,37 @@ def down(): """Downgrade migration.""" return ["DELETE FROM users WHERE name = 'admin';"] ''' - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = 
MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock() - mock_loader.get_down_sql = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock() + mock_loader.get_down_sql = Mock() + mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + mock_await.return_value = Mock(return_value=True) - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "data_migration" - assert metadata["file_path"] == migration_file - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True + assert metadata["version"] == "0001" + assert metadata["description"] == "data_migration" + assert metadata["file_path"] == migration_file + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True -def test_load_migration_metadata_python_file_async() -> None: +def test_load_migration_metadata_python_file_async(tmp_path: Path) -> None: """Test loading metadata from Python migration file with async functions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - migration_file = migrations_path / "0001_async_migration.py" - migration_content = ''' + migration_file = migrations_path / "0001_async_migration.py" + migration_content = ''' import asyncio async def up(): @@ -318,38 +307,37 @@ async def down(): await asyncio.sleep(0.001) return ["DELETE FROM users WHERE name = 'admin';"] ''' - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock() - mock_loader.get_down_sql = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock() + mock_loader.get_down_sql = Mock() + mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + mock_await.return_value = Mock(return_value=True) - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "async_migration" - assert metadata["file_path"] == migration_file - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True + assert metadata["version"] == "0001" + assert metadata["description"] == "async_migration" + assert metadata["file_path"] == migration_file + assert 
metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True -def test_load_migration_metadata_python_file_mixed() -> None: +def test_load_migration_metadata_python_file_mixed(tmp_path: Path) -> None: """Test loading metadata from Python migration file with mixed sync/async functions.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - migration_file = migrations_path / "0001_mixed_migration.py" - migration_content = ''' + migration_file = migrations_path / "0001_mixed_migration.py" + migration_content = ''' import asyncio def up(): @@ -361,38 +349,37 @@ async def down(): await asyncio.sleep(0.001) return ["DELETE FROM users WHERE name = 'admin';"] ''' - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock() - mock_loader.get_down_sql = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock() + mock_loader.get_down_sql = Mock() + mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + mock_await.return_value = Mock(return_value=True) - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "mixed_migration" - assert metadata["file_path"] == migration_file - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True + assert metadata["version"] == "0001" + assert metadata["description"] == "mixed_migration" + assert metadata["file_path"] == migration_file + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True -def test_load_multiple_mixed_migrations() -> None: +def test_load_multiple_mixed_migrations(tmp_path: Path) -> None: """Test loading multiple migrations with mixed SQL and Python (sync/async) files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - sql_migration = migrations_path / "0001_create_tables.sql" - sql_content = """ + sql_migration = migrations_path / "0001_create_tables.sql" + sql_content = """ -- up CREATE TABLE users ( id INTEGER PRIMARY KEY, @@ -413,10 +400,10 @@ def test_load_multiple_mixed_migrations() -> None: DROP TABLE posts; DROP TABLE users; """ - sql_migration.write_text(sql_content) + sql_migration.write_text(sql_content) - python_sync_migration = migrations_path / "0002_seed_data.py" - python_sync_content = ''' + python_sync_migration = migrations_path / "0002_seed_data.py" + python_sync_content = ''' def up(): """Sync upgrade migration to seed initial data.""" return [ @@ -432,10 +419,10 @@ def down(): "DELETE FROM users WHERE email IN ('admin@example.com', 'user1@example.com');" ] ''' - python_sync_migration.write_text(python_sync_content) + python_sync_migration.write_text(python_sync_content) - python_async_migration = migrations_path / 
"0003_async_data_processing.py" - python_async_content = ''' + python_async_migration = migrations_path / "0003_async_data_processing.py" + python_async_content = ''' import asyncio async def up(): @@ -454,10 +441,10 @@ async def down(): "UPDATE users SET name = LOWER(name) WHERE id > 0;" ] ''' - python_async_migration.write_text(python_async_content) + python_async_migration.write_text(python_async_content) - sql_migration2 = migrations_path / "0004_add_indexes.sql" - sql_content2 = """ + sql_migration2 = migrations_path / "0004_add_indexes.sql" + sql_content2 = """ -- up CREATE INDEX idx_users_email ON users(email); CREATE INDEX idx_posts_user_id ON posts(user_id); @@ -468,10 +455,10 @@ async def down(): DROP INDEX idx_posts_user_id; DROP INDEX idx_users_email; """ - sql_migration2.write_text(sql_content2) + sql_migration2.write_text(sql_content2) - python_mixed_migration = migrations_path / "0005_mixed_operations.py" - python_mixed_content = ''' + python_mixed_migration = migrations_path / "0005_mixed_operations.py" + python_mixed_content = ''' import asyncio def up(): @@ -483,115 +470,113 @@ async def down(): await asyncio.sleep(0.001) return ["ALTER TABLE users DROP COLUMN last_login;"] ''' - python_mixed_migration.write_text(python_mixed_content) + python_mixed_migration.write_text(python_mixed_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - runner.loader.has_query = Mock(return_value=True) - runner.loader.load_sql = Mock() - runner.loader.clear_cache = Mock() + runner.loader.has_query = Mock(return_value=True) + runner.loader.load_sql = Mock() + runner.loader.clear_cache = Mock() - migration_files = sorted(migrations_path.glob("*"), key=lambda p: p.name) + migration_files = sorted(migrations_path.glob("*"), key=lambda p: p.name) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock() - mock_loader.get_down_sql = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock() + mock_loader.get_down_sql = Mock() + mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + mock_await.return_value = Mock(return_value=True) - all_metadata = [] - for migration_file in migration_files: - metadata = runner._load_migration_metadata(migration_file) - all_metadata.append(metadata) + all_metadata = [] + for migration_file in migration_files: + metadata = runner._load_migration_metadata(migration_file) + all_metadata.append(metadata) - assert len(all_metadata) == 5 + assert len(all_metadata) == 5 - sql_metadata = [m for m in all_metadata if m["file_path"].suffix == ".sql"] - assert len(sql_metadata) == 2 + sql_metadata = [m for m in all_metadata if m["file_path"].suffix == ".sql"] + assert len(sql_metadata) == 2 - python_metadata = [m for m in all_metadata if m["file_path"].suffix == ".py"] - assert len(python_metadata) == 3 + python_metadata = [m for m in all_metadata if m["file_path"].suffix == ".py"] + assert len(python_metadata) == 3 - expected_migrations = [ - {"version": "0001", "description": "create_tables", "type": "sql"}, - {"version": "0002", "description": 
"seed_data", "type": "python_sync"}, - {"version": "0003", "description": "async_data_processing", "type": "python_async"}, - {"version": "0004", "description": "add_indexes", "type": "sql"}, - {"version": "0005", "description": "mixed_operations", "type": "python_mixed"}, - ] + expected_migrations = [ + {"version": "0001", "description": "create_tables", "type": "sql"}, + {"version": "0002", "description": "seed_data", "type": "python_sync"}, + {"version": "0003", "description": "async_data_processing", "type": "python_async"}, + {"version": "0004", "description": "add_indexes", "type": "sql"}, + {"version": "0005", "description": "mixed_operations", "type": "python_mixed"}, + ] - for i, expected in enumerate(expected_migrations): - metadata = all_metadata[i] - assert metadata["version"] == expected["version"] - assert metadata["description"] == expected["description"] - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True + for i, expected in enumerate(expected_migrations): + metadata = all_metadata[i] + assert metadata["version"] == expected["version"] + assert metadata["description"] == expected["description"] + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True - if expected["type"] == "sql": - assert metadata["file_path"].suffix == ".sql" - else: - assert metadata["file_path"].suffix == ".py" + if expected["type"] == "sql": + assert metadata["file_path"].suffix == ".sql" + else: + assert metadata["file_path"].suffix == ".py" -def test_load_migration_metadata_no_downgrade() -> None: +def test_load_migration_metadata_no_downgrade(tmp_path: Path) -> None: """Test loading metadata when no downgrade is available.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + migrations_path = tmp_path - migration_file = migrations_path / "0001_irreversible.sql" - migration_content = """ + migration_file = migrations_path / "0001_irreversible.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE users ( id INTEGER PRIMARY KEY, name TEXT NOT NULL ); """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - runner.loader.clear_cache = Mock() - runner.loader.load_sql = Mock() - runner.loader.has_query = Mock(side_effect=lambda query: query.endswith("-up")) + runner.loader.clear_cache = Mock() + runner.loader.load_sql = Mock() + runner.loader.has_query = Mock(side_effect=lambda query: query.endswith("-up")) - with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_get_loader.return_value = mock_loader + with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_get_loader.return_value = mock_loader - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is False + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is False -def test_load_migration_metadata_invalid_version() -> None: +def test_load_migration_metadata_invalid_version(tmp_path: Path) -> None: """Test loading metadata with invalid version format.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + 
migrations_path = tmp_path - migration_file = migrations_path / "invalid_name.sql" - migration_content = "CREATE TABLE test (id INTEGER);" - migration_file.write_text(migration_content) + migration_file = migrations_path / "invalid_name.sql" + migration_content = "CREATE TABLE test (id INTEGER);" + migration_file.write_text(migration_content) - runner = MockMigrationRunner(migrations_path) + runner = MockMigrationRunner(migrations_path) - with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_get_loader.return_value = mock_loader + with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_get_loader.return_value = mock_loader - metadata = runner._load_migration_metadata(migration_file) + metadata = runner._load_migration_metadata(migration_file) - assert metadata["version"] is None + assert metadata["version"] is None - assert metadata["description"] == "name" + assert metadata["description"] == "name" def test_get_migration_sql_upgrade() -> None: diff --git a/tests/unit/test_migrations/test_migration_commands.py b/tests/unit/test_migrations/test_migration_commands.py index ec0c03a7..1d8051fd 100644 --- a/tests/unit/test_migrations/test_migration_commands.py +++ b/tests/unit/test_migrations/test_migration_commands.py @@ -10,7 +10,6 @@ - Command routing and parameter passing """ -import tempfile from pathlib import Path from unittest.mock import AsyncMock, patch @@ -49,33 +48,31 @@ def test_migration_commands_async_config_initialization(async_config: AiosqliteC assert hasattr(commands, "runner") -def test_migration_commands_sync_init_delegation(sync_config: SqliteConfig) -> None: +def test_migration_commands_sync_init_delegation(tmp_path: Path, sync_config: SqliteConfig) -> None: """Test that sync config init is delegated directly to sync implementation.""" with patch.object(SyncMigrationCommands, "init") as mock_init: commands = SyncMigrationCommands(sync_config) - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = str(Path(temp_dir) / "migrations") + migration_dir = str(tmp_path / "migrations") - commands.init(migration_dir, package=False) + commands.init(migration_dir, package=False) - mock_init.assert_called_once_with(migration_dir, package=False) + mock_init.assert_called_once_with(migration_dir, package=False) -async def test_migration_commands_async_init_delegation(async_config: AiosqliteConfig) -> None: +async def test_migration_commands_async_init_delegation(tmp_path: Path, async_config: AiosqliteConfig) -> None: """Test that async config init calls async method directly.""" from typing import cast with patch.object(AsyncMigrationCommands, "init", new_callable=AsyncMock) as mock_init: commands = cast(AsyncMigrationCommands, create_migration_commands(async_config)) - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = str(Path(temp_dir) / "migrations") + migration_dir = str(tmp_path / "migrations") - await commands.init(migration_dir, package=True) + await commands.init(migration_dir, package=True) - # Verify the async method was called directly - mock_init.assert_called_once_with(migration_dir, package=True) + # Verify the async method was called directly + mock_init.assert_called_once_with(migration_dir, package=True) def test_migration_commands_sync_current_delegation(sync_config: SqliteConfig) -> None: @@ -232,30 +229,28 @@ def 
test_async_migration_commands_initialization(async_config: AiosqliteConfig) assert hasattr(commands, "runner") -def test_sync_migration_commands_init_creates_directory(sync_config: SqliteConfig) -> None: +def test_sync_migration_commands_init_creates_directory(tmp_path: Path, sync_config: SqliteConfig) -> None: """Test that SyncMigrationCommands init creates migration directory structure.""" commands = SyncMigrationCommands(sync_config) - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - commands.init(str(migration_dir), package=True) + commands.init(str(migration_dir), package=True) - assert migration_dir.exists() - assert (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() -def test_sync_migration_commands_init_without_package(sync_config: SqliteConfig) -> None: +def test_sync_migration_commands_init_without_package(tmp_path: Path, sync_config: SqliteConfig) -> None: """Test that SyncMigrationCommands init creates directory without __init__.py when package=False.""" commands = SyncMigrationCommands(sync_config) - with tempfile.TemporaryDirectory() as temp_dir: - migration_dir = Path(temp_dir) / "migrations" + migration_dir = tmp_path / "migrations" - commands.init(str(migration_dir), package=False) + commands.init(str(migration_dir), package=False) - assert migration_dir.exists() - assert not (migration_dir / "__init__.py").exists() + assert migration_dir.exists() + assert not (migration_dir / "__init__.py").exists() async def test_migration_commands_error_propagation(async_config: AiosqliteConfig) -> None: diff --git a/tests/unit/test_migrations/test_migration_execution.py b/tests/unit/test_migrations/test_migration_execution.py index d68f3f0f..67461d73 100644 --- a/tests/unit/test_migrations/test_migration_execution.py +++ b/tests/unit/test_migrations/test_migration_execution.py @@ -10,10 +10,6 @@ - Migration file processing """ -from __future__ import annotations - -import tempfile -from collections.abc import Generator from pathlib import Path from typing import Any from unittest.mock import Mock, patch @@ -26,21 +22,11 @@ pytestmark = pytest.mark.xdist_group("migrations") -@pytest.fixture -def temp_workspace() -> Generator[Path, None, None]: - """Create a temporary workspace for migration tests.""" - with tempfile.TemporaryDirectory() as temp_dir: - workspace = Path(temp_dir) - yield workspace - - @pytest.fixture def temp_workspace_with_migrations(tmp_path: Path) -> Path: """Create a temporary workspace with migrations directory for tests.""" - migrations_dir = tmp_path / "migrations" migrations_dir.mkdir() - return tmp_path diff --git a/tests/unit/test_migrations/test_migration_runner.py b/tests/unit/test_migrations/test_migration_runner.py index 044fe61f..ad54b2c5 100644 --- a/tests/unit/test_migrations/test_migration_runner.py +++ b/tests/unit/test_migrations/test_migration_runner.py @@ -9,7 +9,6 @@ - Error handling and validation """ -import tempfile import time from pathlib import Path from typing import Any @@ -188,55 +187,46 @@ def test_migration_runner_with_project_root() -> None: assert runner.project_root == project_root -def test_get_migration_files_sorting() -> None: +def test_get_migration_files_sorting(tmp_path: Path) -> None: """Test that migration files are properly sorted by version.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + (tmp_path / 
"0003_add_indexes.sql").write_text("-- Migration 3") + (tmp_path / "0001_initial.sql").write_text("-- Migration 1") + (tmp_path / "0010_final_touches.sql").write_text("-- Migration 10") + (tmp_path / "0002_add_users.sql").write_text("-- Migration 2") - (migrations_path / "0003_add_indexes.sql").write_text("-- Migration 3") - (migrations_path / "0001_initial.sql").write_text("-- Migration 1") - (migrations_path / "0010_final_touches.sql").write_text("-- Migration 10") - (migrations_path / "0002_add_users.sql").write_text("-- Migration 2") - - runner = create_migration_runner_with_sync_files(migrations_path) - files = runner.get_migration_files() + runner = create_migration_runner_with_sync_files(tmp_path) + files = runner.get_migration_files() - expected_order = ["0001", "0002", "0003", "0010"] - actual_order = [version for version, _ in files] + expected_order = ["0001", "0002", "0003", "0010"] + actual_order = [version for version, _ in files] - assert actual_order == expected_order + assert actual_order == expected_order -def test_get_migration_files_mixed_extensions() -> None: +def test_get_migration_files_mixed_extensions(tmp_path: Path) -> None: """Test migration file discovery with mixed SQL and Python files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) + (tmp_path / "0001_schema.sql").write_text("-- SQL Migration") + (tmp_path / "0002_data.py").write_text("# Data migration") + (tmp_path / "0003_more_schema.sql").write_text("-- Another SQL Migration") + (tmp_path / "README.md").write_text("# README") - (migrations_path / "0001_schema.sql").write_text("-- SQL Migration") - (migrations_path / "0002_data.py").write_text("# Data migration") - (migrations_path / "0003_more_schema.sql").write_text("-- Another SQL Migration") - (migrations_path / "README.md").write_text("# README") - - runner = create_migration_runner_with_sync_files(migrations_path) - files = runner.get_migration_files() + runner = create_migration_runner_with_sync_files(tmp_path) + files = runner.get_migration_files() - assert len(files) == 3 - assert files[0][0] == "0001" - assert files[1][0] == "0002" - assert files[2][0] == "0003" + assert len(files) == 3 + assert files[0][0] == "0001" + assert files[1][0] == "0002" + assert files[2][0] == "0003" - assert files[0][1].suffix == ".sql" - assert files[1][1].suffix == ".py" - assert files[2][1].suffix == ".sql" + assert files[0][1].suffix == ".sql" + assert files[1][1].suffix == ".py" + assert files[2][1].suffix == ".sql" -def test_load_migration_metadata_integration() -> None: +def test_load_migration_metadata_integration(tmp_path: Path) -> None: """Test full migration metadata loading process.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_create_users.sql" - migration_content = """ + migration_file = tmp_path / "0001_create_users.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE users ( id INTEGER PRIMARY KEY, @@ -248,38 +238,36 @@ def test_load_migration_metadata_integration() -> None: -- name: migrate-0001-down DROP TABLE users; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - runner = create_migration_runner_with_metadata(migrations_path) + runner = create_migration_runner_with_metadata(tmp_path) - with ( - patch.object(type(runner.loader), "clear_cache"), - patch.object(type(runner.loader), "load_sql"), - patch.object(type(runner.loader), "has_query", side_effect=lambda 
q: True), - ): - with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch.object(type(runner.loader), "clear_cache"), + patch.object(type(runner.loader), "load_sql"), + patch.object(type(runner.loader), "has_query", side_effect=lambda q: True), + ): + with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader: + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_get_loader.return_value = mock_loader - metadata = runner.load_migration(migration_file) + metadata = runner.load_migration(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "create_users" - assert metadata["file_path"] == migration_file - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True - assert isinstance(metadata["checksum"], str) - assert len(metadata["checksum"]) == 32 - assert "loader" in metadata + assert metadata["version"] == "0001" + assert metadata["description"] == "create_users" + assert metadata["file_path"] == migration_file + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True + assert isinstance(metadata["checksum"], str) + assert len(metadata["checksum"]) == 32 + assert "loader" in metadata -def test_load_migration_metadata_prefers_sql_description() -> None: - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - migration_file = migrations_path / "0001_custom.sql" - migration_file.write_text( - """ +def test_load_migration_metadata_prefers_sql_description(tmp_path: Path) -> None: + migration_file = tmp_path / "0001_custom.sql" + migration_file.write_text( + """ -- SQLSpec Migration -- Description: Custom summary -- Author: Example @@ -287,51 +275,46 @@ def test_load_migration_metadata_prefers_sql_description() -> None: -- name: migrate-0001-up SELECT 1; """ - ) + ) - runner = create_migration_runner_with_metadata(migrations_path) + runner = create_migration_runner_with_metadata(tmp_path) - with ( - patch.object(type(runner.loader), "clear_cache"), - patch.object(type(runner.loader), "load_sql"), - patch.object(type(runner.loader), "has_query", return_value=True), - ): - metadata = runner.load_migration(migration_file) + with ( + patch.object(type(runner.loader), "clear_cache"), + patch.object(type(runner.loader), "load_sql"), + patch.object(type(runner.loader), "has_query", return_value=True), + ): + metadata = runner.load_migration(migration_file) - assert metadata["description"] == "Custom summary" + assert metadata["description"] == "Custom summary" -def test_load_migration_metadata_prefers_python_docstring() -> None: - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - migration_file = migrations_path / "0002_feature.py" - migration_file.write_text('"""Description: Add feature"""\n\ndef up():\n return "SELECT 1"\n') +def test_load_migration_metadata_prefers_python_docstring(tmp_path: Path) -> None: + migration_file = tmp_path / "0002_feature.py" + migration_file.write_text('"""Description: Add feature"""\n\ndef up():\n return "SELECT 1"\n') - runner = create_migration_runner_with_metadata(migrations_path) + runner = create_migration_runner_with_metadata(tmp_path) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - 
mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock(return_value=["SELECT 1"]) - mock_loader.get_down_sql = Mock(return_value=None) - mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock(return_value=["SELECT 1"]) + mock_loader.get_down_sql = Mock(return_value=None) + mock_get_loader.return_value = mock_loader + mock_await.return_value = Mock(return_value=True) - metadata = runner.load_migration(migration_file) + metadata = runner.load_migration(migration_file) - assert metadata["description"] == "Add feature" + assert metadata["description"] == "Add feature" -def test_load_migration_metadata_python_file() -> None: +def test_load_migration_metadata_python_file(tmp_path: Path) -> None: """Test metadata loading for Python migration files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_data_migration.py" - python_content = ''' + migration_file = tmp_path / "0001_data_migration.py" + python_content = ''' def up(): """Upgrade migration.""" return [ @@ -346,28 +329,28 @@ def down(): "DELETE FROM users WHERE name = 'admin'" ] ''' - migration_file.write_text(python_content) + migration_file.write_text(python_content) - runner = create_migration_runner_with_metadata(migrations_path) + runner = create_migration_runner_with_metadata(tmp_path) - with ( - patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, - patch("sqlspec.migrations.base.await_") as mock_await, - ): - mock_loader = Mock() - mock_loader.validate_migration_file = Mock() - mock_loader.get_up_sql = Mock() - mock_loader.get_down_sql = Mock() - mock_get_loader.return_value = mock_loader + with ( + patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader, + patch("sqlspec.migrations.base.await_") as mock_await, + ): + mock_loader = Mock() + mock_loader.validate_migration_file = Mock() + mock_loader.get_up_sql = Mock() + mock_loader.get_down_sql = Mock() + mock_get_loader.return_value = mock_loader - mock_await.return_value = Mock(return_value=True) + mock_await.return_value = Mock(return_value=True) - metadata = runner.load_migration(migration_file) + metadata = runner.load_migration(migration_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "data_migration" - assert metadata["has_upgrade"] is True - assert metadata["has_downgrade"] is True + assert metadata["version"] == "0001" + assert metadata["description"] == "data_migration" + assert metadata["has_upgrade"] is True + assert metadata["has_downgrade"] is True def test_get_migration_sql_upgrade_success() -> None: @@ -540,37 +523,31 @@ def test_get_migration_sql_none_statements() -> None: assert result is None -def test_invalid_migration_version_handling() -> None: +def test_invalid_migration_version_handling(tmp_path: Path) -> None: """Test handling of invalid migration version formats.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - invalid_file = migrations_path / "invalid_version_format.sql" - invalid_file.write_text("CREATE TABLE test (id INTEGER);") + invalid_file = tmp_path / "invalid_version_format.sql" + invalid_file.write_text("CREATE TABLE test (id INTEGER);") - runner = 
create_migration_runner_with_sync_files(migrations_path)
-        files = runner.get_migration_files()
+    runner = create_migration_runner_with_sync_files(tmp_path)
+    files = runner.get_migration_files()
 
-        assert len(files) == 0
+    assert len(files) == 0
 
 
-def test_corrupted_migration_file_handling() -> None:
+def test_corrupted_migration_file_handling(tmp_path: Path) -> None:
     """Test handling of corrupted migration files."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migrations_path = Path(temp_dir)
+    corrupted_file = tmp_path / "0001_corrupted.sql"
+    corrupted_file.write_text("This is not a valid migration file content")
 
-        corrupted_file = migrations_path / "0001_corrupted.sql"
-        corrupted_file.write_text("This is not a valid migration file content")
+    runner = create_migration_runner_with_metadata(tmp_path)
 
-        runner = create_migration_runner_with_metadata(migrations_path)
+    with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader:
+        mock_loader = Mock()
+        mock_loader.validate_migration_file.side_effect = Exception("Validation failed")
+        mock_get_loader.return_value = mock_loader
 
-        with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader:
-            mock_loader = Mock()
-            mock_loader.validate_migration_file.side_effect = Exception("Validation failed")
-            mock_get_loader.return_value = mock_loader
-
-            with pytest.raises(Exception):
-                runner.load_migration(corrupted_file)
+        with pytest.raises(Exception, match="Validation failed"):
+            runner.load_migration(corrupted_file)
 
 
 def test_missing_migrations_directory() -> None:
@@ -582,60 +559,54 @@
     assert files == []
 
 
-def test_large_migration_file_handling() -> None:
+def test_large_migration_file_handling(tmp_path: Path) -> None:
     """Test handling of large migration files."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        migrations_path = Path(temp_dir)
+    large_file = tmp_path / "0001_large_migration.sql"
 
-        large_file = migrations_path / "0001_large_migration.sql"
-
-        large_content_parts = [
-            """
+    large_content_parts = [
+        """
 -- name: migrate-0001-up
 CREATE TABLE large_table (
     id INTEGER PRIMARY KEY,
     data TEXT
 );
 """
-        ]
+    ]
 
-        large_content_parts.extend(f"INSERT INTO large_table (data) VALUES ('data_{i:04d}');" for i in range(1000))
+    large_content_parts.extend(f"INSERT INTO large_table (data) VALUES ('data_{i:04d}');" for i in range(1000))
 
-        large_content_parts.append("""
+    large_content_parts.append("""
 -- name: migrate-0001-down
 DROP TABLE large_table;
 """)
 
-        large_content = "\n".join(large_content_parts)
-        large_file.write_text(large_content)
+    large_content = "\n".join(large_content_parts)
+    large_file.write_text(large_content)
 
-        runner = create_migration_runner_with_metadata(migrations_path)
+    runner = create_migration_runner_with_metadata(tmp_path)
 
-        with (
-            patch.object(type(runner.loader), "clear_cache"),
-            patch.object(type(runner.loader), "load_sql"),
-            patch.object(type(runner.loader), "has_query", return_value=True),
-        ):
-            with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader:
-                mock_loader = Mock()
-                mock_loader.validate_migration_file = Mock()
-                mock_get_loader.return_value = mock_loader
+    with (
+        patch.object(type(runner.loader), "clear_cache"),
+        patch.object(type(runner.loader), "load_sql"),
+        patch.object(type(runner.loader), "has_query", return_value=True),
+    ):
+        with patch("sqlspec.migrations.base.get_migration_loader") as mock_get_loader:
+            mock_loader = Mock()
+            mock_loader.validate_migration_file = Mock()
+            mock_get_loader.return_value 
= mock_loader - metadata = runner.load_migration(large_file) + metadata = runner.load_migration(large_file) - assert metadata["version"] == "0001" - assert metadata["description"] == "large_migration" - assert len(metadata["checksum"]) == 32 + assert metadata["version"] == "0001" + assert metadata["description"] == "large_migration" + assert len(metadata["checksum"]) == 32 -def test_many_migration_files_performance() -> None: +def test_many_migration_files_performance(tmp_path: Path) -> None: """Test performance with many migration files.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - for i in range(100): - migration_file = migrations_path / f"{i + 1:04d}_migration_{i}.sql" - migration_file.write_text(f""" + for i in range(100): + migration_file = tmp_path / f"{i + 1:04d}_migration_{i}.sql" + migration_file.write_text(f""" -- name: migrate-{i + 1:04d}-up CREATE TABLE test_table_{i} (id INTEGER PRIMARY KEY); @@ -643,18 +614,18 @@ def test_many_migration_files_performance() -> None: DROP TABLE test_table_{i}; """) - runner = create_migration_runner_with_sync_files(migrations_path) + runner = create_migration_runner_with_sync_files(tmp_path) - files = runner.get_migration_files() + files = runner.get_migration_files() - assert len(files) == 100 + assert len(files) == 100 - for i, (version, _) in enumerate(files): - expected_version = f"{i + 1:04d}" - assert version == expected_version + for i, (version, _) in enumerate(files): + expected_version = f"{i + 1:04d}" + assert version == expected_version -def test_sql_loader_caches_files() -> None: +def test_sql_loader_caches_files(tmp_path: Path) -> None: """Test that SQL migration files leverage CoreSQLFileLoader caching. Verifies fix for bug #118 - duplicate SQL loading during migrations. 
@@ -665,38 +636,35 @@ def test_sql_loader_caches_files() -> None: from sqlspec.migrations.loaders import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_test_migration.sql" - migration_content = """ + migration_file = tmp_path / "0001_test_migration.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE test (id INTEGER PRIMARY KEY); -- name: migrate-0001-down DROP TABLE test; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - sql_loader = SQLFileLoader() + sql_loader = SQLFileLoader() - async def test_operations() -> None: - sql_loader.validate_migration_file(migration_file) - path_str = str(migration_file) - assert path_str in sql_loader.sql_loader._files - assert sql_loader.sql_loader.has_query("migrate-0001-up") - assert sql_loader.sql_loader.has_query("migrate-0001-down") + async def test_operations() -> None: + sql_loader.validate_migration_file(migration_file) + path_str = str(migration_file) + assert path_str in sql_loader.sql_loader._files + assert sql_loader.sql_loader.has_query("migrate-0001-up") + assert sql_loader.sql_loader.has_query("migrate-0001-down") - await sql_loader.get_up_sql(migration_file) - assert path_str in sql_loader.sql_loader._files + await sql_loader.get_up_sql(migration_file) + assert path_str in sql_loader.sql_loader._files - await sql_loader.get_down_sql(migration_file) - assert path_str in sql_loader.sql_loader._files + await sql_loader.get_down_sql(migration_file) + assert path_str in sql_loader.sql_loader._files - asyncio.run(test_operations()) + asyncio.run(test_operations()) -def test_no_duplicate_loading_during_migration_execution() -> None: +def test_no_duplicate_loading_during_migration_execution(tmp_path: Path) -> None: """Test that SQL files are loaded exactly once during migration execution. 
Verifies fix for issue #118 - validates that running a migration loads @@ -707,11 +675,8 @@ def test_no_duplicate_loading_during_migration_execution() -> None: from sqlspec.migrations.loaders import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_create_users.sql" - migration_content = """ + migration_file = tmp_path / "0001_create_users.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE users ( id INTEGER PRIMARY KEY, @@ -721,32 +686,32 @@ def test_no_duplicate_loading_during_migration_execution() -> None: -- name: migrate-0001-down DROP TABLE users; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - sql_loader = SQLFileLoader() + sql_loader = SQLFileLoader() - async def test_migration_workflow() -> None: - sql_loader.validate_migration_file(migration_file) + async def test_migration_workflow() -> None: + sql_loader.validate_migration_file(migration_file) - path_str = str(migration_file) - assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" - assert sql_loader.sql_loader.has_query("migrate-0001-up") - assert sql_loader.sql_loader.has_query("migrate-0001-down") + path_str = str(migration_file) + assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" + assert sql_loader.sql_loader.has_query("migrate-0001-up") + assert sql_loader.sql_loader.has_query("migrate-0001-down") - file_count_after_validation = len(sql_loader.sql_loader._files) + file_count_after_validation = len(sql_loader.sql_loader._files) - await sql_loader.get_up_sql(migration_file) - file_count_after_up = len(sql_loader.sql_loader._files) - assert file_count_after_validation == file_count_after_up, "get_up_sql should not load additional files" + await sql_loader.get_up_sql(migration_file) + file_count_after_up = len(sql_loader.sql_loader._files) + assert file_count_after_validation == file_count_after_up, "get_up_sql should not load additional files" - await sql_loader.get_down_sql(migration_file) - file_count_after_down = len(sql_loader.sql_loader._files) - assert file_count_after_up == file_count_after_down, "get_down_sql should not load additional files" + await sql_loader.get_down_sql(migration_file) + file_count_after_down = len(sql_loader.sql_loader._files) + assert file_count_after_up == file_count_after_down, "get_down_sql should not load additional files" - asyncio.run(test_migration_workflow()) + asyncio.run(test_migration_workflow()) -def test_sql_file_loader_counter_accuracy_single_file() -> None: +def test_sql_file_loader_counter_accuracy_single_file(tmp_path: Path) -> None: """Test SQLFileLoader caching behavior for single file loading. 
Verifies fix for issue #118 (Solution 2) - ensures that load_sql() @@ -755,11 +720,8 @@ def test_sql_file_loader_counter_accuracy_single_file() -> None: """ from sqlspec.loader import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - test_file = temp_path / "test_queries.sql" - test_content = """ + test_file = tmp_path / "test_queries.sql" + test_content = """ -- name: get_user SELECT * FROM users WHERE id = :id; @@ -769,25 +731,25 @@ def test_sql_file_loader_counter_accuracy_single_file() -> None: -- name: delete_user DELETE FROM users WHERE id = :id; """ - test_file.write_text(test_content) + test_file.write_text(test_content) - loader = SQLFileLoader() + loader = SQLFileLoader() - loader.load_sql(test_file) - path_str = str(test_file) - assert path_str in loader._files, "First load should add file to cache" - assert len(loader._queries) == 3, "First load should parse 3 queries" + loader.load_sql(test_file) + path_str = str(test_file) + assert path_str in loader._files, "First load should add file to cache" + assert len(loader._queries) == 3, "First load should parse 3 queries" - query_count_before_reload = len(loader._queries) - file_count_before_reload = len(loader._files) + query_count_before_reload = len(loader._queries) + file_count_before_reload = len(loader._files) - loader.load_sql(test_file) + loader.load_sql(test_file) - assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (cached)" - assert len(loader._files) == file_count_before_reload, "Second load should not add new files (cached)" + assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (cached)" + assert len(loader._files) == file_count_before_reload, "Second load should not add new files (cached)" -def test_sql_file_loader_counter_accuracy_directory() -> None: +def test_sql_file_loader_counter_accuracy_directory(tmp_path: Path) -> None: """Test SQLFileLoader caching behavior for directory loading. 
Verifies that _load_directory() properly caches files and doesn't @@ -795,37 +757,34 @@ def test_sql_file_loader_counter_accuracy_directory() -> None: """ from sqlspec.loader import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - file1 = temp_path / "queries1.sql" - file1.write_text(""" + file1 = tmp_path / "queries1.sql" + file1.write_text(""" -- name: query1 SELECT 1; """) - file2 = temp_path / "queries2.sql" - file2.write_text(""" + file2 = tmp_path / "queries2.sql" + file2.write_text(""" -- name: query2 SELECT 2; """) - loader = SQLFileLoader() + loader = SQLFileLoader() - loader.load_sql(temp_path) - assert len(loader._files) == 2, "First load should add 2 files to cache" - assert len(loader._queries) == 2, "First load should parse 2 queries" + loader.load_sql(tmp_path) + assert len(loader._files) == 2, "First load should add 2 files to cache" + assert len(loader._queries) == 2, "First load should parse 2 queries" - query_count_before_reload = len(loader._queries) - file_count_before_reload = len(loader._files) + query_count_before_reload = len(loader._queries) + file_count_before_reload = len(loader._files) - loader.load_sql(temp_path) + loader.load_sql(tmp_path) - assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (all cached)" - assert len(loader._files) == file_count_before_reload, "Second load should not add new files (all cached)" + assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (all cached)" + assert len(loader._files) == file_count_before_reload, "Second load should not add new files (all cached)" -def test_migration_workflow_single_load_design() -> None: +def test_migration_workflow_single_load_design(tmp_path: Path) -> None: """Test that migration workflow respects single-load design. 
Verifies fix for issue #118 (Solution 1) - confirms that: @@ -839,49 +798,46 @@ def test_migration_workflow_single_load_design() -> None: from sqlspec.migrations.loaders import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_test.sql" - migration_content = """ + migration_file = tmp_path / "0001_test.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE test_table (id INTEGER); -- name: migrate-0001-down DROP TABLE test_table; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - sql_loader = SQLFileLoader() + sql_loader = SQLFileLoader() - async def test_workflow() -> None: - sql_loader.validate_migration_file(migration_file) + async def test_workflow() -> None: + sql_loader.validate_migration_file(migration_file) - path_str = str(migration_file) - assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" - assert sql_loader.sql_loader.has_query("migrate-0001-up") - assert sql_loader.sql_loader.has_query("migrate-0001-down") + path_str = str(migration_file) + assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" + assert sql_loader.sql_loader.has_query("migrate-0001-up") + assert sql_loader.sql_loader.has_query("migrate-0001-down") - file_count_before_up = len(sql_loader.sql_loader._files) - up_sql = await sql_loader.get_up_sql(migration_file) - file_count_after_up = len(sql_loader.sql_loader._files) + file_count_before_up = len(sql_loader.sql_loader._files) + up_sql = await sql_loader.get_up_sql(migration_file) + file_count_after_up = len(sql_loader.sql_loader._files) - assert file_count_before_up == file_count_after_up, "get_up_sql() should not load additional files" - assert len(up_sql) == 1 - assert "CREATE TABLE test_table" in up_sql[0] + assert file_count_before_up == file_count_after_up, "get_up_sql() should not load additional files" + assert len(up_sql) == 1 + assert "CREATE TABLE test_table" in up_sql[0] - file_count_before_down = len(sql_loader.sql_loader._files) - down_sql = await sql_loader.get_down_sql(migration_file) - file_count_after_down = len(sql_loader.sql_loader._files) + file_count_before_down = len(sql_loader.sql_loader._files) + down_sql = await sql_loader.get_down_sql(migration_file) + file_count_after_down = len(sql_loader.sql_loader._files) - assert file_count_before_down == file_count_after_down, "get_down_sql() should not load additional files" - assert len(down_sql) == 1 - assert "DROP TABLE test_table" in down_sql[0] + assert file_count_before_down == file_count_after_down, "get_down_sql() should not load additional files" + assert len(down_sql) == 1 + assert "DROP TABLE test_table" in down_sql[0] - asyncio.run(test_workflow()) + asyncio.run(test_workflow()) -def test_migration_loader_does_not_reload_on_get_sql_calls() -> None: +def test_migration_loader_does_not_reload_on_get_sql_calls(tmp_path: Path) -> None: """Test that get_up_sql and get_down_sql do not trigger file reloads. 
Verifies that after validate_migration_file() loads the file, @@ -893,38 +849,35 @@ def test_migration_loader_does_not_reload_on_get_sql_calls() -> None: from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader from sqlspec.migrations.loaders import SQLFileLoader - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - - migration_file = migrations_path / "0001_schema.sql" - migration_content = """ + migration_file = tmp_path / "0001_schema.sql" + migration_content = """ -- name: migrate-0001-up CREATE TABLE products (id INTEGER, name TEXT); -- name: migrate-0001-down DROP TABLE products; """ - migration_file.write_text(migration_content) + migration_file.write_text(migration_content) - sql_loader = SQLFileLoader() + sql_loader = SQLFileLoader() - call_counts = {"load_sql": 0} - original_load_sql = CoreSQLFileLoader.load_sql + call_counts = {"load_sql": 0} + original_load_sql = CoreSQLFileLoader.load_sql - def counting_load_sql(self: CoreSQLFileLoader, *args: Any, **kwargs: Any) -> None: - call_counts["load_sql"] += 1 - return original_load_sql(self, *args, **kwargs) + def counting_load_sql(self: CoreSQLFileLoader, *args: Any, **kwargs: Any) -> None: + call_counts["load_sql"] += 1 + return original_load_sql(self, *args, **kwargs) - with patch.object(CoreSQLFileLoader, "load_sql", counting_load_sql): + with patch.object(CoreSQLFileLoader, "load_sql", counting_load_sql): - async def test_no_reload() -> None: - sql_loader.validate_migration_file(migration_file) - assert call_counts["load_sql"] == 1, "validate_migration_file should call load_sql exactly once" + async def test_no_reload() -> None: + sql_loader.validate_migration_file(migration_file) + assert call_counts["load_sql"] == 1, "validate_migration_file should call load_sql exactly once" - await sql_loader.get_up_sql(migration_file) - assert call_counts["load_sql"] == 1, "get_up_sql should NOT call load_sql (should use cache)" + await sql_loader.get_up_sql(migration_file) + assert call_counts["load_sql"] == 1, "get_up_sql should NOT call load_sql (should use cache)" - await sql_loader.get_down_sql(migration_file) - assert call_counts["load_sql"] == 1, "get_down_sql should NOT call load_sql (should use cache)" + await sql_loader.get_down_sql(migration_file) + assert call_counts["load_sql"] == 1, "get_down_sql should NOT call load_sql (should use cache)" - asyncio.run(test_no_reload()) + asyncio.run(test_no_reload()) diff --git a/tests/unit/test_migrations/test_null_handling_fixes.py b/tests/unit/test_migrations/test_null_handling_fixes.py index 23a3baed..b1ecdad8 100644 --- a/tests/unit/test_migrations/test_null_handling_fixes.py +++ b/tests/unit/test_migrations/test_null_handling_fixes.py @@ -1,6 +1,5 @@ """Test cases for null handling fixes in migration system.""" -import tempfile from pathlib import Path import pytest @@ -10,102 +9,99 @@ from sqlspec.utils.version import is_sequential_version, is_timestamp_version, parse_version -class TestNullHandlingFixes: - """Test fixes for None value handling in migrations.""" - - def test_parse_version_with_none(self): - """Test parse_version handles None gracefully.""" - with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): - parse_version(None) - - def test_parse_version_with_empty_string(self): - """Test parse_version handles empty string gracefully.""" - with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): - parse_version("") - - def 
test_parse_version_with_whitespace_only(self): - """Test parse_version handles whitespace-only strings.""" - with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): - parse_version(" ") - - def test_parse_version_valid_formats_still_work(self): - """Test that valid version formats still work after fixes.""" - # Sequential versions - result = parse_version("0001") - assert result.type.value == "sequential" - assert result.sequence == 1 - - result = parse_version("9999") - assert result.type.value == "sequential" - assert result.sequence == 9999 - - # Timestamp versions - result = parse_version("20251011120000") - assert result.type.value == "timestamp" - assert result.timestamp is not None - - # Extension versions - result = parse_version("ext_litestar_0001") - assert result.type.value == "sequential" # Base is sequential - assert result.extension == "litestar" - - def test_migration_fixer_handles_none_gracefully(self): - """Test MigrationFixer.update_file_content handles None values.""" - with tempfile.TemporaryDirectory() as temp_dir: - migrations_path = Path(temp_dir) - fixer = MigrationFixer(migrations_path) - - test_file = migrations_path / "test.sql" - test_file.write_text("-- Test content") - - # Should not crash with None values - fixer.update_file_content(test_file, None, "0001") - fixer.update_file_content(test_file, "0001", None) - fixer.update_file_content(test_file, None, None) - - # File should remain unchanged - content = test_file.read_text() - assert content == "-- Test content" - - def test_validation_filters_none_values(self): - """Test migration validation filters None values properly.""" - # Should not crash with None values in lists - gaps = detect_out_of_order_migrations( - pending_versions=["0001", None, "0003", ""], applied_versions=[None, "0002", " ", "0004"] - ) - - # Should only process valid versions - assert len(gaps) >= 0 # Should not crash - - def test_sequential_pattern_edge_cases(self): - """Test sequential pattern handles edge cases.""" - assert is_sequential_version("0001") - assert is_sequential_version("9999") - assert is_sequential_version("10000") - assert not is_sequential_version("20251011120000") # Timestamp - assert not is_sequential_version("abc") - assert not is_sequential_version("") - assert not is_sequential_version(None) - - def test_timestamp_pattern_edge_cases(self): - """Test timestamp pattern handles edge cases.""" - assert is_timestamp_version("20251011120000") - assert is_timestamp_version("20250101000000") - assert is_timestamp_version("20251231235959") - assert not is_timestamp_version("0001") # Sequential - assert not is_timestamp_version("2025101112000") # Too short - assert not is_timestamp_version("202510111200000") # Too long - assert not is_timestamp_version("") - assert not is_timestamp_version(None) - - def test_error_messages_are_descriptive(self): - """Test that error messages are helpful for debugging.""" - try: - parse_version(None) - except ValueError as e: - assert "version string is None or empty" in str(e) - - try: - parse_version("") - except ValueError as e: - assert "version string is None or empty" in str(e) +def test_parse_version_with_none() -> None: + """Test parse_version handles None gracefully.""" + with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version(None) + + +def test_parse_version_with_empty_string() -> None: + """Test parse_version handles empty string gracefully.""" + with 
pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version("") + + +def test_parse_version_with_whitespace_only() -> None: + """Test parse_version handles whitespace-only strings.""" + with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version(" ") + + +def test_parse_version_valid_formats_still_work() -> None: + """Test that valid version formats still work after fixes.""" + result = parse_version("0001") + assert result.type.value == "sequential" + assert result.sequence == 1 + + result = parse_version("9999") + assert result.type.value == "sequential" + assert result.sequence == 9999 + + result = parse_version("20251011120000") + assert result.type.value == "timestamp" + assert result.timestamp is not None + + result = parse_version("ext_litestar_0001") + assert result.type.value == "sequential" + assert result.extension == "litestar" + + +def test_migration_fixer_handles_none_gracefully(tmp_path: Path) -> None: + """Test MigrationFixer.update_file_content handles None values.""" + migrations_path = tmp_path + fixer = MigrationFixer(migrations_path) + + test_file = migrations_path / "test.sql" + test_file.write_text("-- Test content") + + fixer.update_file_content(test_file, None, "0001") + fixer.update_file_content(test_file, "0001", None) + fixer.update_file_content(test_file, None, None) + + content = test_file.read_text() + assert content == "-- Test content" + + +def test_validation_filters_none_values() -> None: + """Test migration validation filters None values properly.""" + gaps = detect_out_of_order_migrations( + pending_versions=["0001", None, "0003", ""], applied_versions=[None, "0002", " ", "0004"] + ) + + assert len(gaps) >= 0 + + +def test_sequential_pattern_edge_cases() -> None: + """Test sequential pattern handles edge cases.""" + assert is_sequential_version("0001") + assert is_sequential_version("9999") + assert is_sequential_version("10000") + assert not is_sequential_version("20251011120000") + assert not is_sequential_version("abc") + assert not is_sequential_version("") + assert not is_sequential_version(None) + + +def test_timestamp_pattern_edge_cases() -> None: + """Test timestamp pattern handles edge cases.""" + assert is_timestamp_version("20251011120000") + assert is_timestamp_version("20250101000000") + assert is_timestamp_version("20251231235959") + assert not is_timestamp_version("0001") + assert not is_timestamp_version("2025101112000") + assert not is_timestamp_version("202510111200000") + assert not is_timestamp_version("") + assert not is_timestamp_version(None) + + +def test_error_messages_are_descriptive() -> None: + """Test that error messages are helpful for debugging.""" + try: + parse_version(None) + except ValueError as e: + assert "version string is None or empty" in str(e) + + try: + parse_version("") + except ValueError as e: + assert "version string is None or empty" in str(e) diff --git a/tests/unit/test_storage/test_fsspec_backend.py b/tests/unit/test_storage/test_fsspec_backend.py index b86198b0..2121b004 100644 --- a/tests/unit/test_storage/test_fsspec_backend.py +++ b/tests/unit/test_storage/test_fsspec_backend.py @@ -1,6 +1,5 @@ """Unit tests for FSSpecBackend.""" -import tempfile from pathlib import Path from typing import Any @@ -43,228 +42,215 @@ def test_from_config() -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_write_and_read_bytes() -> None: +def 
test_write_and_read_bytes(tmp_path: Path) -> None: """Test write and read bytes operations.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_data = b"test data content" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_data = b"test data content" - store.write_bytes("test_file.bin", test_data) - result = store.read_bytes("test_file.bin") + store.write_bytes("test_file.bin", test_data) + result = store.read_bytes("test_file.bin") - assert result == test_data + assert result == test_data @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_write_and_read_text() -> None: +def test_write_and_read_text(tmp_path: Path) -> None: """Test write and read text operations.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_text = "test text content\nwith multiple lines" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_text = "test text content\nwith multiple lines" - store.write_text("test_file.txt", test_text) - result = store.read_text("test_file.txt") + store.write_text("test_file.txt", test_text) + result = store.read_text("test_file.txt") - assert result == test_text + assert result == test_text @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_exists() -> None: +def test_exists(tmp_path: Path) -> None: """Test exists operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - assert not store.exists("nonexistent.txt") + assert not store.exists("nonexistent.txt") - store.write_text("existing.txt", "content") - assert store.exists("existing.txt") + store.write_text("existing.txt", "content") + assert store.exists("existing.txt") @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_delete() -> None: +def test_delete(tmp_path: Path) -> None: """Test delete operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - store.write_text("to_delete.txt", "content") - assert store.exists("to_delete.txt") + store.write_text("to_delete.txt", "content") + assert store.exists("to_delete.txt") - store.delete("to_delete.txt") - assert not store.exists("to_delete.txt") + store.delete("to_delete.txt") + assert not store.exists("to_delete.txt") @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_copy() -> None: +def test_copy(tmp_path: Path) -> None: """Test copy operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - original_content = "original content" + store = FSSpecBackend("file", base_path=str(tmp_path)) + original_content = "original content" - store.write_text("original.txt", original_content) - store.copy("original.txt", "copied.txt") + store.write_text("original.txt", original_content) + store.copy("original.txt", "copied.txt") - assert store.exists("copied.txt") - assert store.read_text("copied.txt") == original_content + assert store.exists("copied.txt") + 
assert store.read_text("copied.txt") == original_content @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_move() -> None: +def test_move(tmp_path: Path) -> None: """Test move operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - original_content = "content to move" + store = FSSpecBackend("file", base_path=str(tmp_path)) + original_content = "content to move" - store.write_text("original.txt", original_content) - store.move("original.txt", "moved.txt") + store.write_text("original.txt", original_content) + store.move("original.txt", "moved.txt") - assert not store.exists("original.txt") - assert store.exists("moved.txt") - assert store.read_text("moved.txt") == original_content + assert not store.exists("original.txt") + assert store.exists("moved.txt") + assert store.read_text("moved.txt") == original_content @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_list_objects() -> None: +def test_list_objects(tmp_path: Path) -> None: """Test list_objects operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test files - store.write_text("file1.txt", "content1") - store.write_text("file2.txt", "content2") - store.write_text("subdir/file3.txt", "content3") + # Create test files + store.write_text("file1.txt", "content1") + store.write_text("file2.txt", "content2") + store.write_text("subdir/file3.txt", "content3") - # List all objects - all_objects = store.list_objects() - assert any("file1.txt" in obj for obj in all_objects) - assert any("file2.txt" in obj for obj in all_objects) - assert any("file3.txt" in obj for obj in all_objects) + # List all objects + all_objects = store.list_objects() + assert any("file1.txt" in obj for obj in all_objects) + assert any("file2.txt" in obj for obj in all_objects) + assert any("file3.txt" in obj for obj in all_objects) @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_glob() -> None: +def test_glob(tmp_path: Path) -> None: """Test glob pattern matching.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test files - store.write_text("test1.sql", "SELECT 1") - store.write_text("test2.sql", "SELECT 2") - store.write_text("config.json", "{}") + # Create test files + store.write_text("test1.sql", "SELECT 1") + store.write_text("test2.sql", "SELECT 2") + store.write_text("config.json", "{}") - # Test glob patterns - sql_files = store.glob("*.sql") - assert any("test1.sql" in obj for obj in sql_files) - assert any("test2.sql" in obj for obj in sql_files) - assert not any("config.json" in obj for obj in sql_files) + # Test glob patterns + sql_files = store.glob("*.sql") + assert any("test1.sql" in obj for obj in sql_files) + assert any("test2.sql" in obj for obj in sql_files) + assert not any("config.json" in obj for obj in sql_files) @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_get_metadata() -> None: +def test_get_metadata(tmp_path: Path) -> None: """Test get_metadata operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with 
tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_content = "test content for metadata" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_content = "test content for metadata" - store.write_text("test_file.txt", test_content) - metadata = store.get_metadata("test_file.txt") + store.write_text("test_file.txt", test_content) + metadata = store.get_metadata("test_file.txt") - assert "size" in metadata - assert "exists" in metadata - assert metadata["exists"] is True + assert "size" in metadata + assert "exists" in metadata + assert metadata["exists"] is True @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_is_object_and_is_path() -> None: +def test_is_object_and_is_path(tmp_path: Path) -> None: """Test is_object and is_path operations.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - store.write_text("file.txt", "content") - Path(temp_dir, "subdir").mkdir() + store.write_text("file.txt", "content") + (tmp_path / "subdir").mkdir() - assert store.is_object("file.txt") - assert not store.is_object("subdir") - assert not store.is_path("file.txt") - assert store.is_path("subdir") + assert store.is_object("file.txt") + assert not store.is_object("subdir") + assert not store.is_path("file.txt") + assert store.is_path("subdir") @pytest.mark.skipif(not FSSPEC_INSTALLED or not PYARROW_INSTALLED, reason="fsspec or PyArrow not installed") -def test_write_and_read_arrow() -> None: +def test_write_and_read_arrow(tmp_path: Path) -> None: """Test write and read Arrow table operations.""" import pyarrow as pa from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} + table = pa.table(data) - store.write_arrow("test_data.parquet", table) - result = store.read_arrow("test_data.parquet") + store.write_arrow("test_data.parquet", table) + result = store.read_arrow("test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not FSSPEC_INSTALLED or not PYARROW_INSTALLED, reason="fsspec or PyArrow not installed") -def test_stream_arrow() -> None: +def test_stream_arrow(tmp_path: Path) -> None: """Test stream Arrow record batches.""" import pyarrow as pa from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} + table = pa.table(data) - store.write_arrow("stream_test.parquet", table) + store.write_arrow("stream_test.parquet", table) - # Stream record batches - batches = list(store.stream_arrow("stream_test.parquet")) - assert len(batches) > 0 + # Stream 
record batches + batches = list(store.stream_arrow("stream_test.parquet")) + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_sign_returns_uri() -> None: +def test_sign_returns_uri(tmp_path: Path) -> None: """Test sign returns URI for files.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") + store.write_text("test.txt", "content") + signed_url = store.sign("test.txt") - assert "test.txt" in signed_url + assert "test.txt" in signed_url def test_fsspec_not_installed() -> None: @@ -282,196 +268,185 @@ def test_fsspec_not_installed() -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_write_and_read_bytes() -> None: +async def test_async_write_and_read_bytes(tmp_path: Path) -> None: """Test async write and read bytes operations.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_data = b"async test data content" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_data = b"async test data content" - await store.write_bytes_async("async_test_file.bin", test_data) - result = await store.read_bytes_async("async_test_file.bin") + await store.write_bytes_async("async_test_file.bin", test_data) + result = await store.read_bytes_async("async_test_file.bin") - assert result == test_data + assert result == test_data @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_write_and_read_text() -> None: +async def test_async_write_and_read_text(tmp_path: Path) -> None: """Test async write and read text operations.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_text = "async test text content\nwith multiple lines" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_text = "async test text content\nwith multiple lines" - await store.write_text_async("async_test_file.txt", test_text) - result = await store.read_text_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_text) + result = await store.read_text_async("async_test_file.txt") - assert result == test_text + assert result == test_text @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_exists() -> None: +async def test_async_exists(tmp_path: Path) -> None: """Test async exists operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - assert not await store.exists_async("async_nonexistent.txt") + assert not await store.exists_async("async_nonexistent.txt") - await store.write_text_async("async_existing.txt", "content") - assert await store.exists_async("async_existing.txt") + await 
store.write_text_async("async_existing.txt", "content") + assert await store.exists_async("async_existing.txt") @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_delete() -> None: +async def test_async_delete(tmp_path: Path) -> None: """Test async delete operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - await store.write_text_async("async_to_delete.txt", "content") - assert await store.exists_async("async_to_delete.txt") + await store.write_text_async("async_to_delete.txt", "content") + assert await store.exists_async("async_to_delete.txt") - await store.delete_async("async_to_delete.txt") - assert not await store.exists_async("async_to_delete.txt") + await store.delete_async("async_to_delete.txt") + assert not await store.exists_async("async_to_delete.txt") @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_copy() -> None: +async def test_async_copy(tmp_path: Path) -> None: """Test async copy operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - original_content = "async original content" + store = FSSpecBackend("file", base_path=str(tmp_path)) + original_content = "async original content" - await store.write_text_async("async_original.txt", original_content) - await store.copy_async("async_original.txt", "async_copied.txt") + await store.write_text_async("async_original.txt", original_content) + await store.copy_async("async_original.txt", "async_copied.txt") - assert await store.exists_async("async_copied.txt") - assert await store.read_text_async("async_copied.txt") == original_content + assert await store.exists_async("async_copied.txt") + assert await store.read_text_async("async_copied.txt") == original_content @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_move() -> None: +async def test_async_move(tmp_path: Path) -> None: """Test async move operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - original_content = "async content to move" + store = FSSpecBackend("file", base_path=str(tmp_path)) + original_content = "async content to move" - await store.write_text_async("async_original.txt", original_content) - await store.move_async("async_original.txt", "async_moved.txt") + await store.write_text_async("async_original.txt", original_content) + await store.move_async("async_original.txt", "async_moved.txt") - assert not await store.exists_async("async_original.txt") - assert await store.exists_async("async_moved.txt") - assert await store.read_text_async("async_moved.txt") == original_content + assert not await store.exists_async("async_original.txt") + assert await store.exists_async("async_moved.txt") + assert await store.read_text_async("async_moved.txt") == original_content @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_list_objects() -> None: +async def test_async_list_objects(tmp_path: Path) -> None: """Test async list_objects operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", 
base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test files - await store.write_text_async("async_file1.txt", "content1") - await store.write_text_async("async_file2.txt", "content2") - await store.write_text_async("async_subdir/file3.txt", "content3") + # Create test files + await store.write_text_async("async_file1.txt", "content1") + await store.write_text_async("async_file2.txt", "content2") + await store.write_text_async("async_subdir/file3.txt", "content3") - # List all objects - all_objects = await store.list_objects_async() - assert any("file1.txt" in obj for obj in all_objects) - assert any("file2.txt" in obj for obj in all_objects) - assert any("file3.txt" in obj for obj in all_objects) + # List all objects + all_objects = await store.list_objects_async() + assert any("file1.txt" in obj for obj in all_objects) + assert any("file2.txt" in obj for obj in all_objects) + assert any("file3.txt" in obj for obj in all_objects) @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_get_metadata() -> None: +async def test_async_get_metadata(tmp_path: Path) -> None: """Test async get_metadata operation.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) - test_content = "async test content for metadata" + store = FSSpecBackend("file", base_path=str(tmp_path)) + test_content = "async test content for metadata" - await store.write_text_async("async_test_file.txt", test_content) - metadata = await store.get_metadata_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_content) + metadata = await store.get_metadata_async("async_test_file.txt") - assert "size" in metadata - assert "exists" in metadata - assert metadata["exists"] is True + assert "size" in metadata + assert "exists" in metadata + assert metadata["exists"] is True @pytest.mark.skipif(not FSSPEC_INSTALLED or not PYARROW_INSTALLED, reason="fsspec or PyArrow not installed") -async def test_async_write_and_read_arrow() -> None: +async def test_async_write_and_read_arrow(tmp_path: Path) -> None: """Test async write and read Arrow table operations.""" import pyarrow as pa from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = { - "id": [1, 2, 3, 4], - "name": ["Alice", "Bob", "Charlie", "David"], - "score": [95.5, 87.0, 92.3, 89.7], - } - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = { + "id": [1, 2, 3, 4], + "name": ["Alice", "Bob", "Charlie", "David"], + "score": [95.5, 87.0, 92.3, 89.7], + } + table = pa.table(data) - await store.write_arrow_async("async_test_data.parquet", table) - result = await store.read_arrow_async("async_test_data.parquet") + await store.write_arrow_async("async_test_data.parquet", table) + result = await store.read_arrow_async("async_test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not FSSPEC_INSTALLED or not PYARROW_INSTALLED, reason="fsspec or PyArrow not installed") -async def test_async_stream_arrow() -> None: +async def test_async_stream_arrow(tmp_path: Path) -> None: """Test async stream Arrow record batches.""" import pyarrow as pa from sqlspec.storage.backends.fsspec import FSSpecBackend - 
with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} + table = pa.table(data) - await store.write_arrow_async("async_stream_test.parquet", table) + await store.write_arrow_async("async_stream_test.parquet", table) - # Stream record batches - batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] + # Stream record batches + batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] - assert len(batches) > 0 + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -async def test_async_sign() -> None: +async def test_async_sign(tmp_path: Path) -> None: """Test async sign returns URI for local files.""" from sqlspec.storage.backends.fsspec import FSSpecBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") + await store.write_text_async("async_test.txt", "content") + signed_url = await store.sign_async("async_test.txt") - assert "async_test.txt" in signed_url + assert "async_test.txt" in signed_url def test_fsspec_operations_without_fsspec() -> None: @@ -484,21 +459,20 @@ def test_fsspec_operations_without_fsspec() -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_arrow_operations_without_pyarrow() -> None: +def test_arrow_operations_without_pyarrow(tmp_path: Path) -> None: """Test Arrow operations raise proper error without PyArrow.""" from sqlspec.storage.backends.fsspec import FSSpecBackend if PYARROW_INSTALLED: pytest.skip("PyArrow is installed") - with tempfile.TemporaryDirectory() as temp_dir: - store = FSSpecBackend("file", base_path=temp_dir) + store = FSSpecBackend("file", base_path=str(tmp_path)) - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.read_arrow("test.parquet") + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.read_arrow("test.parquet") - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.write_arrow("test.parquet", None) # type: ignore + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.write_arrow("test.parquet", None) # type: ignore - with pytest.raises(MissingDependencyError, match="pyarrow"): - list(store.stream_arrow("*.parquet")) + with pytest.raises(MissingDependencyError, match="pyarrow"): + list(store.stream_arrow("*.parquet")) diff --git a/tests/unit/test_storage/test_local_store.py b/tests/unit/test_storage/test_local_store.py index c226ba43..84f99b45 100644 --- a/tests/unit/test_storage/test_local_store.py +++ b/tests/unit/test_storage/test_local_store.py @@ -2,7 +2,6 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Unit tests for LocalStore backend.""" -import tempfile from pathlib 
import Path from typing import Any @@ -14,18 +13,16 @@ from sqlspec.typing import PYARROW_INSTALLED -def test_init_with_file_uri() -> None: +def test_init_with_file_uri(tmp_path: Path) -> None: """Test initialization with file:// URI.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(f"file://{temp_dir}") - assert store.base_path == Path(temp_dir).resolve() + store = LocalStore(f"file://{tmp_path}") + assert store.base_path == tmp_path.resolve() -def test_init_with_path_string() -> None: +def test_init_with_path_string(tmp_path: Path) -> None: """Test initialization with plain path string.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - assert store.base_path == Path(temp_dir).resolve() + store = LocalStore(str(tmp_path)) + assert store.base_path == tmp_path.resolve() def test_init_empty_defaults_to_cwd() -> None: @@ -34,458 +31,426 @@ def test_init_empty_defaults_to_cwd() -> None: assert store.base_path == Path.cwd() -def test_write_and_read_bytes() -> None: +def test_write_and_read_bytes(tmp_path: Path) -> None: """Test write and read bytes operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_data = b"test data content" + store = LocalStore(str(tmp_path)) + test_data = b"test data content" - store.write_bytes("test_file.bin", test_data) - result = store.read_bytes("test_file.bin") + store.write_bytes("test_file.bin", test_data) + result = store.read_bytes("test_file.bin") - assert result == test_data + assert result == test_data -def test_write_and_read_text() -> None: +def test_write_and_read_text(tmp_path: Path) -> None: """Test write and read text operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_text = "test text content\nwith multiple lines" + store = LocalStore(str(tmp_path)) + test_text = "test text content\nwith multiple lines" - store.write_text("test_file.txt", test_text) - result = store.read_text("test_file.txt") + store.write_text("test_file.txt", test_text) + result = store.read_text("test_file.txt") - assert result == test_text + assert result == test_text -def test_write_and_read_text_custom_encoding() -> None: +def test_write_and_read_text_custom_encoding(tmp_path: Path) -> None: """Test write and read text with custom encoding.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_text = "test with ünicode" + store = LocalStore(str(tmp_path)) + test_text = "test with ünicode" - store.write_text("test_file.txt", test_text, encoding="latin-1") - result = store.read_text("test_file.txt", encoding="latin-1") + store.write_text("test_file.txt", test_text, encoding="latin-1") + result = store.read_text("test_file.txt", encoding="latin-1") - assert result == test_text + assert result == test_text -def test_exists() -> None: +def test_exists(tmp_path: Path) -> None: """Test exists operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - assert not store.exists("nonexistent.txt") + assert not store.exists("nonexistent.txt") - store.write_text("existing.txt", "content") - assert store.exists("existing.txt") + store.write_text("existing.txt", "content") + assert store.exists("existing.txt") -def test_delete() -> None: +def test_delete(tmp_path: Path) -> None: """Test delete operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - 
store.write_text("to_delete.txt", "content") - assert store.exists("to_delete.txt") + store.write_text("to_delete.txt", "content") + assert store.exists("to_delete.txt") - store.delete("to_delete.txt") - assert not store.exists("to_delete.txt") + store.delete("to_delete.txt") + assert not store.exists("to_delete.txt") -def test_copy() -> None: +def test_copy(tmp_path: Path) -> None: """Test copy operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - original_content = "original content" + store = LocalStore(str(tmp_path)) + original_content = "original content" - store.write_text("original.txt", original_content) - store.copy("original.txt", "copied.txt") + store.write_text("original.txt", original_content) + store.copy("original.txt", "copied.txt") - assert store.exists("copied.txt") - assert store.read_text("copied.txt") == original_content + assert store.exists("copied.txt") + assert store.read_text("copied.txt") == original_content -def test_move() -> None: +def test_move(tmp_path: Path) -> None: """Test move operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - original_content = "content to move" + store = LocalStore(str(tmp_path)) + original_content = "content to move" - store.write_text("original.txt", original_content) - store.move("original.txt", "moved.txt") + store.write_text("original.txt", original_content) + store.move("original.txt", "moved.txt") - assert not store.exists("original.txt") - assert store.exists("moved.txt") - assert store.read_text("moved.txt") == original_content + assert not store.exists("original.txt") + assert store.exists("moved.txt") + assert store.read_text("moved.txt") == original_content -def test_list_objects() -> None: +def test_list_objects(tmp_path: Path) -> None: """Test list_objects operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test files - store.write_text("file1.txt", "content1") - store.write_text("file2.txt", "content2") - store.write_text("subdir/file3.txt", "content3") + # Create test files + store.write_text("file1.txt", "content1") + store.write_text("file2.txt", "content2") + store.write_text("subdir/file3.txt", "content3") - # List all objects - all_objects = store.list_objects() - assert "file1.txt" in all_objects - assert "file2.txt" in all_objects - assert "subdir/file3.txt" in all_objects + # List all objects + all_objects = store.list_objects() + assert "file1.txt" in all_objects + assert "file2.txt" in all_objects + assert "subdir/file3.txt" in all_objects -def test_list_objects_with_prefix() -> None: +def test_list_objects_with_prefix(tmp_path: Path) -> None: """Test list_objects with prefix filtering.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test files - store.write_text("prefix_file1.txt", "content1") - store.write_text("prefix_file2.txt", "content2") - store.write_text("other_file.txt", "content3") + # Create test files + store.write_text("prefix_file1.txt", "content1") + store.write_text("prefix_file2.txt", "content2") + store.write_text("other_file.txt", "content3") - # List with prefix - prefixed_objects = store.list_objects(prefix="prefix_") - assert "prefix_file1.txt" in prefixed_objects - assert "prefix_file2.txt" in prefixed_objects - assert "other_file.txt" not in prefixed_objects + # List with prefix + prefixed_objects = 
store.list_objects(prefix="prefix_") + assert "prefix_file1.txt" in prefixed_objects + assert "prefix_file2.txt" in prefixed_objects + assert "other_file.txt" not in prefixed_objects -def test_glob() -> None: +def test_glob(tmp_path: Path) -> None: """Test glob pattern matching.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test files - store.write_text("test1.sql", "SELECT 1") - store.write_text("test2.sql", "SELECT 2") - store.write_text("config.json", "{}") - store.write_text("subdir/test3.sql", "SELECT 3") + # Create test files + store.write_text("test1.sql", "SELECT 1") + store.write_text("test2.sql", "SELECT 2") + store.write_text("config.json", "{}") + store.write_text("subdir/test3.sql", "SELECT 3") - # Test glob patterns - sql_files = store.glob("*.sql") - assert "test1.sql" in sql_files - assert "test2.sql" in sql_files - assert "config.json" not in sql_files + # Test glob patterns + sql_files = store.glob("*.sql") + assert "test1.sql" in sql_files + assert "test2.sql" in sql_files + assert "config.json" not in sql_files -def test_get_metadata() -> None: +def test_get_metadata(tmp_path: Path) -> None: """Test get_metadata operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_content = "test content for metadata" + store = LocalStore(str(tmp_path)) + test_content = "test content for metadata" - store.write_text("test_file.txt", test_content) - metadata = store.get_metadata("test_file.txt") + store.write_text("test_file.txt", test_content) + metadata = store.get_metadata("test_file.txt") - assert "size" in metadata - assert "modified" in metadata - assert metadata["size"] == len(test_content.encode()) + assert "size" in metadata + assert "modified" in metadata + assert metadata["size"] == len(test_content.encode()) -def test_is_object_and_is_path() -> None: +def test_is_object_and_is_path(tmp_path: Path) -> None: """Test is_object and is_path operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - store.write_text("file.txt", "content") - (Path(temp_dir) / "subdir").mkdir() + store.write_text("file.txt", "content") + (tmp_path / "subdir").mkdir() - assert store.is_object("file.txt") - assert not store.is_object("subdir") - assert not store.is_path("file.txt") - assert store.is_path("subdir") + assert store.is_object("file.txt") + assert not store.is_object("subdir") + assert not store.is_path("file.txt") + assert store.is_path("subdir") @pytest.mark.skipif(not PYARROW_INSTALLED, reason="PyArrow not installed") -def test_write_and_read_arrow() -> None: +def test_write_and_read_arrow(tmp_path: Path) -> None: """Test write and read Arrow table operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} + table = pa.table(data) - store.write_arrow("test_data.parquet", table) - result = store.read_arrow("test_data.parquet") + store.write_arrow("test_data.parquet", table) + result = store.read_arrow("test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not 
PYARROW_INSTALLED, reason="PyArrow not installed") -def test_stream_arrow() -> None: +def test_stream_arrow(tmp_path: Path) -> None: """Test stream Arrow record batches.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} + table = pa.table(data) - store.write_arrow("stream_test.parquet", table) + store.write_arrow("stream_test.parquet", table) - # Stream record batches - batches = list(store.stream_arrow("stream_test.parquet")) - assert len(batches) > 0 + # Stream record batches + batches = list(store.stream_arrow("stream_test.parquet")) + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) -def test_sign_returns_file_uri() -> None: +def test_sign_returns_file_uri(tmp_path: Path) -> None: """Test sign returns file:// URI for local files.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") + store.write_text("test.txt", "content") + signed_url = store.sign("test.txt") - assert signed_url.startswith("file://") - assert "test.txt" in signed_url + assert signed_url.startswith("file://") + assert "test.txt" in signed_url -def test_sign_with_options() -> None: +def test_sign_with_options(tmp_path: Path) -> None: """Test sign with expires_in and for_upload options.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - store.write_text("test.txt", "content") + store.write_text("test.txt", "content") - # Options are ignored for local files but should not error - signed_url = store.sign("test.txt", expires_in=7200, for_upload=True) - assert signed_url.startswith("file://") + # Options are ignored for local files but should not error + signed_url = store.sign("test.txt", expires_in=7200, for_upload=True) + assert signed_url.startswith("file://") -def test_resolve_path_absolute() -> None: +def test_resolve_path_absolute(tmp_path: Path) -> None: """Test path resolution with absolute paths.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Absolute path should be returned as-is - test_path = Path(temp_dir) / "test.txt" - store.write_text("test.txt", "content") + # Absolute path should be returned as-is + test_path = tmp_path / "test.txt" + store.write_text("test.txt", "content") - resolved = store._resolve_path(str(test_path)) - assert resolved == test_path + resolved = store._resolve_path(str(test_path)) + assert resolved == test_path -def test_resolve_path_relative() -> None: +def test_resolve_path_relative(tmp_path: Path) -> None: """Test path resolution with relative paths.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - resolved = store._resolve_path("subdir/file.txt") - expected = Path(temp_dir).resolve() / "subdir" / "file.txt" - assert resolved == expected + resolved = store._resolve_path("subdir/file.txt") + expected = 
tmp_path.resolve() / "subdir" / "file.txt" + assert resolved == expected -def test_nested_directory_operations() -> None: +def test_nested_directory_operations(tmp_path: Path) -> None: """Test operations with nested directories.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Write to nested path - store.write_text("level1/level2/file.txt", "nested content") - assert store.exists("level1/level2/file.txt") - assert store.read_text("level1/level2/file.txt") == "nested content" + # Write to nested path + store.write_text("level1/level2/file.txt", "nested content") + assert store.exists("level1/level2/file.txt") + assert store.read_text("level1/level2/file.txt") == "nested content" - # List should include nested files - objects = store.list_objects() - assert "level1/level2/file.txt" in objects + # List should include nested files + objects = store.list_objects() + assert "level1/level2/file.txt" in objects -def test_file_not_found_errors() -> None: +def test_file_not_found_errors(tmp_path: Path) -> None: """Test operations on non-existent files raise appropriate errors.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - with pytest.raises(FileNotFoundError): - store.read_bytes("nonexistent.bin") + with pytest.raises(FileNotFoundError): + store.read_bytes("nonexistent.bin") - with pytest.raises(FileNotFoundError): - store.read_text("nonexistent.txt") + with pytest.raises(FileNotFoundError): + store.read_text("nonexistent.txt") # Async tests -async def test_async_write_and_read_bytes() -> None: +async def test_async_write_and_read_bytes(tmp_path: Path) -> None: """Test async write and read bytes operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_data = b"async test data content" + store = LocalStore(str(tmp_path)) + test_data = b"async test data content" - await store.write_bytes_async("async_test_file.bin", test_data) - result = await store.read_bytes_async("async_test_file.bin") + await store.write_bytes_async("async_test_file.bin", test_data) + result = await store.read_bytes_async("async_test_file.bin") - assert result == test_data + assert result == test_data -async def test_async_write_and_read_text() -> None: +async def test_async_write_and_read_text(tmp_path: Path) -> None: """Test async write and read text operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_text = "async test text content\nwith multiple lines" + store = LocalStore(str(tmp_path)) + test_text = "async test text content\nwith multiple lines" - await store.write_text_async("async_test_file.txt", test_text) - result = await store.read_text_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_text) + result = await store.read_text_async("async_test_file.txt") - assert result == test_text + assert result == test_text -async def test_async_exists() -> None: +async def test_async_exists(tmp_path: Path) -> None: """Test async exists operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - assert not await store.exists_async("async_nonexistent.txt") + assert not await store.exists_async("async_nonexistent.txt") - await store.write_text_async("async_existing.txt", "content") - assert await store.exists_async("async_existing.txt") + await 
store.write_text_async("async_existing.txt", "content") + assert await store.exists_async("async_existing.txt") -async def test_async_delete() -> None: +async def test_async_delete(tmp_path: Path) -> None: """Test async delete operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - await store.write_text_async("async_to_delete.txt", "content") - assert await store.exists_async("async_to_delete.txt") + await store.write_text_async("async_to_delete.txt", "content") + assert await store.exists_async("async_to_delete.txt") - await store.delete_async("async_to_delete.txt") - assert not await store.exists_async("async_to_delete.txt") + await store.delete_async("async_to_delete.txt") + assert not await store.exists_async("async_to_delete.txt") -async def test_async_copy() -> None: +async def test_async_copy(tmp_path: Path) -> None: """Test async copy operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - original_content = "async original content" + store = LocalStore(str(tmp_path)) + original_content = "async original content" - await store.write_text_async("async_original.txt", original_content) - await store.copy_async("async_original.txt", "async_copied.txt") + await store.write_text_async("async_original.txt", original_content) + await store.copy_async("async_original.txt", "async_copied.txt") - assert await store.exists_async("async_copied.txt") - assert await store.read_text_async("async_copied.txt") == original_content + assert await store.exists_async("async_copied.txt") + assert await store.read_text_async("async_copied.txt") == original_content -async def test_async_move() -> None: +async def test_async_move(tmp_path: Path) -> None: """Test async move operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - original_content = "async content to move" + store = LocalStore(str(tmp_path)) + original_content = "async content to move" - await store.write_text_async("async_original.txt", original_content) - await store.move_async("async_original.txt", "async_moved.txt") + await store.write_text_async("async_original.txt", original_content) + await store.move_async("async_original.txt", "async_moved.txt") - assert not await store.exists_async("async_original.txt") - assert await store.exists_async("async_moved.txt") - assert await store.read_text_async("async_moved.txt") == original_content + assert not await store.exists_async("async_original.txt") + assert await store.exists_async("async_moved.txt") + assert await store.read_text_async("async_moved.txt") == original_content -async def test_async_list_objects() -> None: +async def test_async_list_objects(tmp_path: Path) -> None: """Test async list_objects operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test files - await store.write_text_async("async_file1.txt", "content1") - await store.write_text_async("async_file2.txt", "content2") - await store.write_text_async("async_subdir/file3.txt", "content3") + # Create test files + await store.write_text_async("async_file1.txt", "content1") + await store.write_text_async("async_file2.txt", "content2") + await store.write_text_async("async_subdir/file3.txt", "content3") - # List all objects - all_objects = await store.list_objects_async() - assert "async_file1.txt" in all_objects - assert "async_file2.txt" in all_objects - assert 
"async_subdir/file3.txt" in all_objects + # List all objects + all_objects = await store.list_objects_async() + assert "async_file1.txt" in all_objects + assert "async_file2.txt" in all_objects + assert "async_subdir/file3.txt" in all_objects -async def test_async_get_metadata() -> None: +async def test_async_get_metadata(tmp_path: Path) -> None: """Test async get_metadata operation.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) - test_content = "async test content for metadata" + store = LocalStore(str(tmp_path)) + test_content = "async test content for metadata" - await store.write_text_async("async_test_file.txt", test_content) - metadata = await store.get_metadata_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_content) + metadata = await store.get_metadata_async("async_test_file.txt") - assert "size" in metadata - assert "modified" in metadata - assert metadata["size"] == len(test_content.encode()) + assert "size" in metadata + assert "modified" in metadata + assert metadata["size"] == len(test_content.encode()) @pytest.mark.skipif(not PYARROW_INSTALLED, reason="PyArrow not installed") -async def test_async_write_and_read_arrow() -> None: +async def test_async_write_and_read_arrow(tmp_path: Path) -> None: """Test async write and read Arrow table operations.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = { - "id": [1, 2, 3, 4], - "name": ["Alice", "Bob", "Charlie", "David"], - "score": [95.5, 87.0, 92.3, 89.7], - } - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = { + "id": [1, 2, 3, 4], + "name": ["Alice", "Bob", "Charlie", "David"], + "score": [95.5, 87.0, 92.3, 89.7], + } + table = pa.table(data) - await store.write_arrow_async("async_test_data.parquet", table) - result = await store.read_arrow_async("async_test_data.parquet") + await store.write_arrow_async("async_test_data.parquet", table) + result = await store.read_arrow_async("async_test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not PYARROW_INSTALLED, reason="PyArrow not installed") -async def test_async_stream_arrow() -> None: +async def test_async_stream_arrow(tmp_path: Path) -> None: """Test async stream Arrow record batches.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} + table = pa.table(data) - await store.write_arrow_async("async_stream_test.parquet", table) + await store.write_arrow_async("async_stream_test.parquet", table) - # Stream record batches - batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] + # Stream record batches + batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] - assert len(batches) > 0 + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) -async def test_async_sign() -> None: +async def 
test_async_sign(tmp_path: Path) -> None: """Test async sign returns file:// URI for local files.""" - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") + await store.write_text_async("async_test.txt", "content") + signed_url = await store.sign_async("async_test.txt") - assert signed_url.startswith("file://") - assert "async_test.txt" in signed_url + assert signed_url.startswith("file://") + assert "async_test.txt" in signed_url -def test_arrow_operations_without_pyarrow() -> None: +def test_arrow_operations_without_pyarrow(tmp_path: Path) -> None: """Test Arrow operations raise proper error without PyArrow.""" if PYARROW_INSTALLED: pytest.skip("PyArrow is installed") - with tempfile.TemporaryDirectory() as temp_dir: - store = LocalStore(temp_dir) + store = LocalStore(str(tmp_path)) - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.read_arrow("test.parquet") + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.read_arrow("test.parquet") - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.write_arrow("test.parquet", None) # type: ignore + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.write_arrow("test.parquet", None) # type: ignore - with pytest.raises(MissingDependencyError, match="pyarrow"): - list(store.stream_arrow("*.parquet")) + with pytest.raises(MissingDependencyError, match="pyarrow"): + list(store.stream_arrow("*.parquet")) diff --git a/tests/unit/test_storage/test_obstore_backend.py b/tests/unit/test_storage/test_obstore_backend.py index 562a00a3..ce61febc 100644 --- a/tests/unit/test_storage/test_obstore_backend.py +++ b/tests/unit/test_storage/test_obstore_backend.py @@ -1,6 +1,6 @@ """Unit tests for ObStoreBackend.""" -import tempfile +from pathlib import Path from typing import Any import pytest @@ -13,250 +13,235 @@ @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_init_with_file_uri() -> None: +def test_init_with_file_uri(tmp_path: Path) -> None: """Test initialization with file:// URI.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - assert store.base_path == "" + store = ObStoreBackend(f"file://{tmp_path}") + assert store.base_path == "" @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_from_config() -> None: +def test_from_config(tmp_path: Path) -> None: """Test from_config class method.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - data_dir = f"{temp_dir}/data" - config = {"store_uri": f"file://{data_dir}", "store_options": {}} - store = ObStoreBackend.from_config(config) - assert store.base_path == "" + data_dir = f"{tmp_path}/data" + config = {"store_uri": f"file://{data_dir}", "store_options": {}} + store = ObStoreBackend.from_config(config) + assert store.base_path == "" @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_write_and_read_bytes() -> None: +def test_write_and_read_bytes(tmp_path: Path) -> None: """Test write and read bytes operations.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_data = b"test data 
content" + store = ObStoreBackend(f"file://{tmp_path}") + test_data = b"test data content" - store.write_bytes("test_file.bin", test_data) - result = store.read_bytes("test_file.bin") + store.write_bytes("test_file.bin", test_data) + result = store.read_bytes("test_file.bin") - assert result == test_data + assert result == test_data @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_write_and_read_text() -> None: +def test_write_and_read_text(tmp_path: Path) -> None: """Test write and read text operations.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_text = "test text content\nwith multiple lines" + store = ObStoreBackend(f"file://{tmp_path}") + test_text = "test text content\nwith multiple lines" - store.write_text("test_file.txt", test_text) - result = store.read_text("test_file.txt") + store.write_text("test_file.txt", test_text) + result = store.read_text("test_file.txt") - assert result == test_text + assert result == test_text @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_exists() -> None: +def test_exists(tmp_path: Path) -> None: """Test exists operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - assert not store.exists("nonexistent.txt") + assert not store.exists("nonexistent.txt") - store.write_text("existing.txt", "content") - assert store.exists("existing.txt") + store.write_text("existing.txt", "content") + assert store.exists("existing.txt") @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_delete() -> None: +def test_delete(tmp_path: Path) -> None: """Test delete operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - store.write_text("to_delete.txt", "content") - assert store.exists("to_delete.txt") + store.write_text("to_delete.txt", "content") + assert store.exists("to_delete.txt") - store.delete("to_delete.txt") - assert not store.exists("to_delete.txt") + store.delete("to_delete.txt") + assert not store.exists("to_delete.txt") @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_copy() -> None: +def test_copy(tmp_path: Path) -> None: """Test copy operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - original_content = "original content" + store = ObStoreBackend(f"file://{tmp_path}") + original_content = "original content" - store.write_text("original.txt", original_content) - store.copy("original.txt", "copied.txt") + store.write_text("original.txt", original_content) + store.copy("original.txt", "copied.txt") - assert store.exists("copied.txt") - assert store.read_text("copied.txt") == original_content + assert store.exists("copied.txt") + assert store.read_text("copied.txt") == original_content @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_move() -> None: +def test_move(tmp_path: Path) -> None: """Test move operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - 
store = ObStoreBackend(f"file://{temp_dir}") - original_content = "content to move" + store = ObStoreBackend(f"file://{tmp_path}") + original_content = "content to move" - store.write_text("original.txt", original_content) - store.move("original.txt", "moved.txt") + store.write_text("original.txt", original_content) + store.move("original.txt", "moved.txt") - assert not store.exists("original.txt") - assert store.exists("moved.txt") - assert store.read_text("moved.txt") == original_content + assert not store.exists("original.txt") + assert store.exists("moved.txt") + assert store.read_text("moved.txt") == original_content @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_list_objects() -> None: +def test_list_objects(tmp_path: Path) -> None: """Test list_objects operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test files - store.write_text("file1.txt", "content1") - store.write_text("file2.txt", "content2") - store.write_text("subdir/file3.txt", "content3") + # Create test files + store.write_text("file1.txt", "content1") + store.write_text("file2.txt", "content2") + store.write_text("subdir/file3.txt", "content3") - # List all objects - all_objects = store.list_objects() - assert any("file1.txt" in obj for obj in all_objects) - assert any("file2.txt" in obj for obj in all_objects) - assert any("file3.txt" in obj for obj in all_objects) + # List all objects + all_objects = store.list_objects() + assert any("file1.txt" in obj for obj in all_objects) + assert any("file2.txt" in obj for obj in all_objects) + assert any("file3.txt" in obj for obj in all_objects) @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_glob() -> None: +def test_glob(tmp_path: Path) -> None: """Test glob pattern matching.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test files - store.write_text("test1.sql", "SELECT 1") - store.write_text("test2.sql", "SELECT 2") - store.write_text("config.json", "{}") + # Create test files + store.write_text("test1.sql", "SELECT 1") + store.write_text("test2.sql", "SELECT 2") + store.write_text("config.json", "{}") - # Test glob patterns - sql_files = store.glob("*.sql") - assert any("test1.sql" in obj for obj in sql_files) - assert any("test2.sql" in obj for obj in sql_files) - assert not any("config.json" in obj for obj in sql_files) + # Test glob patterns + sql_files = store.glob("*.sql") + assert any("test1.sql" in obj for obj in sql_files) + assert any("test2.sql" in obj for obj in sql_files) + assert not any("config.json" in obj for obj in sql_files) @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_get_metadata() -> None: +def test_get_metadata(tmp_path: Path) -> None: """Test get_metadata operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_content = "test content for metadata" + store = ObStoreBackend(f"file://{tmp_path}") + test_content = "test content for metadata" - store.write_text("test_file.txt", test_content) - metadata = store.get_metadata("test_file.txt") + store.write_text("test_file.txt", 
test_content) + metadata = store.get_metadata("test_file.txt") - assert "exists" in metadata - assert metadata["exists"] is True + assert "exists" in metadata + assert metadata["exists"] is True @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_is_object_and_is_path() -> None: +def test_is_object_and_is_path(tmp_path: Path) -> None: """Test is_object and is_path operations.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - store.write_text("file.txt", "content") - # Create directory by writing file inside it - store.write_text("subdir/nested.txt", "content") + store.write_text("file.txt", "content") + # Create directory by writing file inside it + store.write_text("subdir/nested.txt", "content") - assert store.is_object("file.txt") - assert not store.is_object("subdir") - assert not store.is_path("file.txt") - assert store.is_path("subdir") + assert store.is_object("file.txt") + assert not store.is_object("subdir") + assert not store.is_path("file.txt") + assert store.is_path("subdir") @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow not installed") -def test_write_and_read_arrow() -> None: +def test_write_and_read_arrow(tmp_path: Path) -> None: """Test write and read Arrow table operations.""" import pyarrow as pa from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "score": [95.5, 87.0, 92.3]} + table = pa.table(data) - store.write_arrow("test_data.parquet", table) - result = store.read_arrow("test_data.parquet") + store.write_arrow("test_data.parquet", table) + result = store.read_arrow("test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow not installed") -def test_stream_arrow() -> None: +def test_stream_arrow(tmp_path: Path) -> None: """Test stream Arrow record batches.""" import pyarrow as pa from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]} + table = pa.table(data) - store.write_arrow("stream_test.parquet", table) + store.write_arrow("stream_test.parquet", table) - # Stream record batches - batches = list(store.stream_arrow("stream_test.parquet")) - assert len(batches) > 0 + # Stream record batches + batches = list(store.stream_arrow("stream_test.parquet")) + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) 
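+    # Extra sanity sketch (assumption: uses only the public pyarrow attributes
+    # RecordBatch.num_rows and Table.num_rows): the Parquet reader may split
+    # rows across any number of batches, so the stable invariant is the total
+    # row count rather than the batch count.
+    assert sum(batch.num_rows for batch in batches) == table.num_rows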
@pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_sign_returns_uri() -> None: +def test_sign_returns_uri(tmp_path: Path) -> None: """Test sign returns URI for files.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - store.write_text("test.txt", "content") - signed_url = store.sign("test.txt") + store.write_text("test.txt", "content") + signed_url = store.sign("test.txt") - assert "test.txt" in signed_url + assert "test.txt" in signed_url def test_obstore_not_installed() -> None: @@ -272,195 +257,184 @@ def test_obstore_not_installed() -> None: @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_write_and_read_bytes() -> None: +async def test_async_write_and_read_bytes(tmp_path: Path) -> None: """Test async write and read bytes operations.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_data = b"async test data content" + store = ObStoreBackend(f"file://{tmp_path}") + test_data = b"async test data content" - await store.write_bytes_async("async_test_file.bin", test_data) - result = await store.read_bytes_async("async_test_file.bin") + await store.write_bytes_async("async_test_file.bin", test_data) + result = await store.read_bytes_async("async_test_file.bin") - assert result == test_data + assert result == test_data @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_write_and_read_text() -> None: +async def test_async_write_and_read_text(tmp_path: Path) -> None: """Test async write and read text operations.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_text = "async test text content\nwith multiple lines" + store = ObStoreBackend(f"file://{tmp_path}") + test_text = "async test text content\nwith multiple lines" - await store.write_text_async("async_test_file.txt", test_text) - result = await store.read_text_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_text) + result = await store.read_text_async("async_test_file.txt") - assert result == test_text + assert result == test_text @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_exists() -> None: +async def test_async_exists(tmp_path: Path) -> None: """Test async exists operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - assert not await store.exists_async("async_nonexistent.txt") + assert not await store.exists_async("async_nonexistent.txt") - await store.write_text_async("async_existing.txt", "content") - assert await store.exists_async("async_existing.txt") + await store.write_text_async("async_existing.txt", "content") + assert await store.exists_async("async_existing.txt") @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_delete() -> None: +async def test_async_delete(tmp_path: Path) -> None: """Test async delete operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as 
temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - await store.write_text_async("async_to_delete.txt", "content") - assert await store.exists_async("async_to_delete.txt") + await store.write_text_async("async_to_delete.txt", "content") + assert await store.exists_async("async_to_delete.txt") - await store.delete_async("async_to_delete.txt") - assert not await store.exists_async("async_to_delete.txt") + await store.delete_async("async_to_delete.txt") + assert not await store.exists_async("async_to_delete.txt") @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_copy() -> None: +async def test_async_copy(tmp_path: Path) -> None: """Test async copy operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - original_content = "async original content" + store = ObStoreBackend(f"file://{tmp_path}") + original_content = "async original content" - await store.write_text_async("async_original.txt", original_content) - await store.copy_async("async_original.txt", "async_copied.txt") + await store.write_text_async("async_original.txt", original_content) + await store.copy_async("async_original.txt", "async_copied.txt") - assert await store.exists_async("async_copied.txt") - assert await store.read_text_async("async_copied.txt") == original_content + assert await store.exists_async("async_copied.txt") + assert await store.read_text_async("async_copied.txt") == original_content @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_move() -> None: +async def test_async_move(tmp_path: Path) -> None: """Test async move operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - original_content = "async content to move" + store = ObStoreBackend(f"file://{tmp_path}") + original_content = "async content to move" - await store.write_text_async("async_original.txt", original_content) - await store.move_async("async_original.txt", "async_moved.txt") + await store.write_text_async("async_original.txt", original_content) + await store.move_async("async_original.txt", "async_moved.txt") - assert not await store.exists_async("async_original.txt") - assert await store.exists_async("async_moved.txt") - assert await store.read_text_async("async_moved.txt") == original_content + assert not await store.exists_async("async_original.txt") + assert await store.exists_async("async_moved.txt") + assert await store.read_text_async("async_moved.txt") == original_content @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_list_objects() -> None: +async def test_async_list_objects(tmp_path: Path) -> None: """Test async list_objects operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test files - await store.write_text_async("async_file1.txt", "content1") - await store.write_text_async("async_file2.txt", "content2") - await store.write_text_async("async_subdir/file3.txt", "content3") + # Create test files + await store.write_text_async("async_file1.txt", "content1") + await store.write_text_async("async_file2.txt", "content2") + await 
store.write_text_async("async_subdir/file3.txt", "content3") - # List all objects - all_objects = await store.list_objects_async() - assert any("file1.txt" in obj for obj in all_objects) - assert any("file2.txt" in obj for obj in all_objects) - assert any("file3.txt" in obj for obj in all_objects) + # List all objects + all_objects = await store.list_objects_async() + assert any("file1.txt" in obj for obj in all_objects) + assert any("file2.txt" in obj for obj in all_objects) + assert any("file3.txt" in obj for obj in all_objects) @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_get_metadata() -> None: +async def test_async_get_metadata(tmp_path: Path) -> None: """Test async get_metadata operation.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") - test_content = "async test content for metadata" + store = ObStoreBackend(f"file://{tmp_path}") + test_content = "async test content for metadata" - await store.write_text_async("async_test_file.txt", test_content) - metadata = await store.get_metadata_async("async_test_file.txt") + await store.write_text_async("async_test_file.txt", test_content) + metadata = await store.get_metadata_async("async_test_file.txt") - assert "exists" in metadata - assert metadata["exists"] is True + assert "exists" in metadata + assert metadata["exists"] is True @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow not installed") -async def test_async_write_and_read_arrow() -> None: +async def test_async_write_and_read_arrow(tmp_path: Path) -> None: """Test async write and read Arrow table operations.""" import pyarrow as pa from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test Arrow table - data: dict[str, Any] = { - "id": [1, 2, 3, 4], - "name": ["Alice", "Bob", "Charlie", "David"], - "score": [95.5, 87.0, 92.3, 89.7], - } - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = { + "id": [1, 2, 3, 4], + "name": ["Alice", "Bob", "Charlie", "David"], + "score": [95.5, 87.0, 92.3, 89.7], + } + table = pa.table(data) - await store.write_arrow_async("async_test_data.parquet", table) - result = await store.read_arrow_async("async_test_data.parquet") + await store.write_arrow_async("async_test_data.parquet", table) + result = await store.read_arrow_async("async_test_data.parquet") - assert result.equals(table) + assert result.equals(table) @pytest.mark.skipif(not OBSTORE_INSTALLED or not PYARROW_INSTALLED, reason="obstore or PyArrow not installed") -async def test_async_stream_arrow() -> None: +async def test_async_stream_arrow(tmp_path: Path) -> None: """Test async stream Arrow record batches.""" import pyarrow as pa from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - # Create test Arrow table - data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} - table = pa.table(data) + # Create test Arrow table + data: dict[str, Any] = {"id": [1, 2, 3, 4, 5, 6], "value": ["a", "b", "c", "d", "e", "f"]} + table = pa.table(data) - await store.write_arrow_async("async_stream_test.parquet", table) + await 
store.write_arrow_async("async_stream_test.parquet", table) - # Stream record batches - batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] + # Stream record batches + batches = [batch async for batch in store.stream_arrow_async("async_stream_test.parquet")] - assert len(batches) > 0 + assert len(batches) > 0 - # Verify we can read the data - reconstructed = pa.Table.from_batches(batches) - assert reconstructed.equals(table) + # Verify we can read the data + reconstructed = pa.Table.from_batches(batches) + assert reconstructed.equals(table) @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -async def test_async_sign() -> None: +async def test_async_sign(tmp_path: Path) -> None: """Test async sign returns URI for files.""" from sqlspec.storage.backends.obstore import ObStoreBackend - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - await store.write_text_async("async_test.txt", "content") - signed_url = await store.sign_async("async_test.txt") + await store.write_text_async("async_test.txt", "content") + signed_url = await store.sign_async("async_test.txt") - assert "async_test.txt" in signed_url + assert "async_test.txt" in signed_url def test_obstore_operations_without_obstore() -> None: @@ -473,21 +447,20 @@ def test_obstore_operations_without_obstore() -> None: @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_arrow_operations_without_pyarrow() -> None: +def test_arrow_operations_without_pyarrow(tmp_path: Path) -> None: """Test Arrow operations raise proper error without PyArrow.""" from sqlspec.storage.backends.obstore import ObStoreBackend if PYARROW_INSTALLED: pytest.skip("PyArrow is installed") - with tempfile.TemporaryDirectory() as temp_dir: - store = ObStoreBackend(f"file://{temp_dir}") + store = ObStoreBackend(f"file://{tmp_path}") - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.read_arrow("test.parquet") + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.read_arrow("test.parquet") - with pytest.raises(MissingDependencyError, match="pyarrow"): - store.write_arrow("test.parquet", None) # type: ignore + with pytest.raises(MissingDependencyError, match="pyarrow"): + store.write_arrow("test.parquet", None) # type: ignore - with pytest.raises(MissingDependencyError, match="pyarrow"): - list(store.stream_arrow("*.parquet")) + with pytest.raises(MissingDependencyError, match="pyarrow"): + list(store.stream_arrow("*.parquet")) diff --git a/tests/unit/test_storage/test_storage_registry.py b/tests/unit/test_storage/test_storage_registry.py index 08f068f8..74c4f401 100644 --- a/tests/unit/test_storage/test_storage_registry.py +++ b/tests/unit/test_storage/test_storage_registry.py @@ -1,7 +1,6 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Unit tests for StorageRegistry.""" -import tempfile from pathlib import Path import pytest @@ -46,95 +45,79 @@ def test_register_alias() -> None: assert "test_store" in registry.list_aliases() -def test_get_local_backend() -> None: +def test_get_local_backend(tmp_path: Path) -> None: """Test getting local backend (when explicitly requested).""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - # Force local backend with override - backend = registry.get(temp_dir, backend="local") - assert backend.backend_type == "local" + backend 
= registry.get(str(tmp_path), backend="local") + assert backend.backend_type == "local" - # Force local backend for file:// URI - backend = registry.get(f"file://{temp_dir}", backend="local") - assert backend.backend_type == "local" + backend = registry.get(f"file://{tmp_path}", backend="local") + assert backend.backend_type == "local" @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_get_local_backend_prefers_obstore() -> None: +def test_get_local_backend_prefers_obstore(tmp_path: Path) -> None: """Test that local paths prefer obstore when available.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - # Without backend override, should get obstore for file:// - backend = registry.get(f"file://{temp_dir}") - assert backend.backend_type == "obstore" + backend = registry.get(f"file://{tmp_path}") + assert backend.backend_type == "obstore" - # Direct path should also prefer obstore - backend = registry.get(temp_dir) - assert backend.backend_type == "obstore" + backend = registry.get(str(tmp_path)) + assert backend.backend_type == "obstore" - # Can still force local backend with override - backend = registry.get(f"file://{temp_dir}", backend="local") - assert backend.backend_type == "local" + backend = registry.get(f"file://{tmp_path}", backend="local") + assert backend.backend_type == "local" -def test_get_local_backend_fallback_priority() -> None: +def test_get_local_backend_fallback_priority(tmp_path: Path) -> None: """Test backend fallback priority for local paths.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - # Get whatever backend is available - backend = registry.get(f"file://{temp_dir}") + backend = registry.get(f"file://{tmp_path}") - # Should be one of: obstore > fsspec > local (in that priority order) - if OBSTORE_INSTALLED: - assert backend.backend_type == "obstore" - elif FSSPEC_INSTALLED: - assert backend.backend_type == "fsspec" - else: - assert backend.backend_type == "local" + if OBSTORE_INSTALLED: + assert backend.backend_type == "obstore" + elif FSSPEC_INSTALLED: + assert backend.backend_type == "fsspec" + else: + assert backend.backend_type == "local" -def test_get_alias() -> None: +def test_get_alias(tmp_path: Path) -> None: """Test getting backend by alias.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() - registry.register_alias("my_store", f"file://{temp_dir}") + registry = StorageRegistry() + registry.register_alias("my_store", f"file://{tmp_path}") - backend = registry.get("my_store") - # Backend type depends on what's installed - assert backend.backend_type in ("obstore", "fsspec", "local") + backend = registry.get("my_store") + assert backend.backend_type in ("obstore", "fsspec", "local") -def test_get_with_backend_override() -> None: +def test_get_with_backend_override(tmp_path: Path) -> None: """Test getting backend with override.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - # Force local backend - backend = registry.get(f"file://{temp_dir}", backend="local") - assert backend.backend_type == "local" + backend = registry.get(f"file://{tmp_path}", backend="local") + assert backend.backend_type == "local" @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") -def test_get_fsspec_backend() -> None: +def test_get_fsspec_backend(tmp_path: Path) -> None: """Test getting fsspec 
backend.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - backend = registry.get(f"file://{temp_dir}", backend="fsspec") - assert backend.backend_type == "fsspec" + backend = registry.get(f"file://{tmp_path}", backend="fsspec") + assert backend.backend_type == "fsspec" @pytest.mark.skipif(not OBSTORE_INSTALLED, reason="obstore not installed") -def test_get_obstore_backend() -> None: +def test_get_obstore_backend(tmp_path: Path) -> None: """Test getting obstore backend.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - backend = registry.get(f"file://{temp_dir}", backend="obstore") - assert backend.backend_type == "obstore" + backend = registry.get(f"file://{tmp_path}", backend="obstore") + assert backend.backend_type == "obstore" def test_get_invalid_alias_raises_error() -> None: @@ -161,53 +144,45 @@ def test_get_invalid_backend_raises_error() -> None: registry.get("file:///tmp", backend="invalid") -def test_register_alias_with_base_path() -> None: +def test_register_alias_with_base_path(tmp_path: Path) -> None: """Test alias registration with base_path.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - registry.register_alias("test_store", f"file://{temp_dir}/data") - backend = registry.get("test_store") + registry.register_alias("test_store", f"file://{tmp_path}/data") + backend = registry.get("test_store") - # Write and read to verify base_path works - backend.write_text("test.txt", "content") - assert backend.exists("test.txt") + backend.write_text("test.txt", "content") + assert backend.exists("test.txt") -def test_register_alias_with_backend_override() -> None: +def test_register_alias_with_backend_override(tmp_path: Path) -> None: """Test alias registration with backend override.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - registry.register_alias("test_store", f"file://{temp_dir}", backend="local") - backend = registry.get("test_store") - assert backend.backend_type == "local" + registry.register_alias("test_store", f"file://{tmp_path}", backend="local") + backend = registry.get("test_store") + assert backend.backend_type == "local" -def test_cache_functionality() -> None: +def test_cache_functionality(tmp_path: Path) -> None: """Test registry caching.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - # Get same backend twice - backend1 = registry.get(f"file://{temp_dir}") - backend2 = registry.get(f"file://{temp_dir}") + backend1 = registry.get(f"file://{tmp_path}") + backend2 = registry.get(f"file://{tmp_path}") - # Should be the same instance - assert backend1 is backend2 + assert backend1 is backend2 -def test_clear_cache() -> None: +def test_clear_cache(tmp_path: Path) -> None: """Test cache clearing.""" - with tempfile.TemporaryDirectory() as temp_dir: - registry = StorageRegistry() + registry = StorageRegistry() - backend1 = registry.get(f"file://{temp_dir}") - registry.clear_cache(f"file://{temp_dir}") - backend2 = registry.get(f"file://{temp_dir}") + backend1 = registry.get(f"file://{tmp_path}") + registry.clear_cache(f"file://{tmp_path}") + backend2 = registry.get(f"file://{tmp_path}") - # Should be different instances after cache clear - assert backend1 is not backend2 + assert backend1 is not backend2 def test_clear_aliases() -> None: 
@@ -222,47 +197,41 @@ def test_clear_aliases() -> None:
     assert len(registry.list_aliases()) == 0
 
 
-def test_clear_instances() -> None:
+def test_clear_instances(tmp_path: Path) -> None:
     """Test clearing instances."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        registry = StorageRegistry()
+    registry = StorageRegistry()
 
-        backend1 = registry.get(f"file://{temp_dir}")
-        registry.clear_instances()
-        backend2 = registry.get(f"file://{temp_dir}")
+    backend1 = registry.get(f"file://{tmp_path}")
+    registry.clear_instances()
+    backend2 = registry.get(f"file://{tmp_path}")
 
-        # Should be different instances after clear
-        assert backend1 is not backend2
+    assert backend1 is not backend2
 
 
-def test_clear_all() -> None:
+def test_clear_all(tmp_path: Path) -> None:
     """Test clearing everything."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        registry = StorageRegistry()
+    registry = StorageRegistry()
 
-        registry.register_alias("test_store", f"file://{temp_dir}")
-        backend1 = registry.get("test_store")
+    registry.register_alias("test_store", f"file://{tmp_path}")
+    backend1 = registry.get("test_store")
 
-        registry.clear()
+    registry.clear()
 
-        assert not registry.is_alias_registered("test_store")
-        assert len(registry.list_aliases()) == 0
+    assert not registry.is_alias_registered("test_store")
+    assert len(registry.list_aliases()) == 0
 
-        # Should create new instance
-        registry.register_alias("test_store", f"file://{temp_dir}")
-        backend2 = registry.get("test_store")
-        assert backend1 is not backend2
+    registry.register_alias("test_store", f"file://{tmp_path}")
+    backend2 = registry.get("test_store")
+    assert backend1 is not backend2
 
 
-def test_path_object_conversion() -> None:
+def test_path_object_conversion(tmp_path: Path) -> None:
     """Test Path object conversion to file:// URI."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        registry = StorageRegistry()
-        path_obj = Path(temp_dir)
+    registry = StorageRegistry()
+    path_obj = tmp_path
 
-        backend = registry.get(path_obj)
-        # Backend type depends on what's installed (obstore > fsspec > local)
-        assert backend.backend_type in ("obstore", "fsspec", "local")
+    backend = registry.get(path_obj)
+    assert backend.backend_type in ("obstore", "fsspec", "local")
 
 
 def test_cloud_storage_without_backends() -> None:
diff --git a/tests/unit/test_utils/test_fixtures.py b/tests/unit/test_utils/test_fixtures.py
index 2f9cd863..eba2dd5c 100644
--- a/tests/unit/test_utils/test_fixtures.py
+++ b/tests/unit/test_utils/test_fixtures.py
@@ -7,7 +7,6 @@
 
 import gzip
 import json
-import tempfile
 import zipfile
 from pathlib import Path
 from typing import Any
@@ -28,115 +27,106 @@
 pytestmark = pytest.mark.xdist_group("utils")
 
 
-def test_find_fixture_file_json() -> None:
+def test_find_fixture_file_json(tmp_path: Path) -> None:
     """Test finding regular .json fixture file."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = Path(temp_dir)
-        fixture_file = fixtures_path / "test.json"
-        fixture_file.write_text('{"test": "data"}')
+    fixtures_path = tmp_path
+    fixture_file = fixtures_path / "test.json"
+    fixture_file.write_text('{"test": "data"}')
 
-        result = _find_fixture_file(fixtures_path, "test")
-        assert result == fixture_file
+    result = _find_fixture_file(fixtures_path, "test")
+    assert result == fixture_file
 
 
-def test_find_fixture_file_gz_priority() -> None:
-    """Test .json.gz takes priority over .json when both exist."""
+def test_find_fixture_file_gz_priority(tmp_path: Path) -> None:
+    """Test .json takes priority over .json.gz when both exist."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = 
Path(temp_dir) - json_file = fixtures_path / "test.json" - gz_file = fixtures_path / "test.json.gz" + fixtures_path = tmp_path + json_file = fixtures_path / "test.json" + gz_file = fixtures_path / "test.json.gz" - json_file.write_text('{"test": "json"}') - with gzip.open(gz_file, "wt") as f: - json.dump({"test": "gz"}, f) + json_file.write_text('{"test": "json"}') + with gzip.open(gz_file, "wt") as f: + json.dump({"test": "gz"}, f) - result = _find_fixture_file(fixtures_path, "test") - assert result == json_file # .json has highest priority + result = _find_fixture_file(fixtures_path, "test") + assert result == json_file # .json has highest priority -def test_find_fixture_file_zip_fallback() -> None: +def test_find_fixture_file_zip_fallback(tmp_path: Path) -> None: """Test .json.zip is found when .json and .json.gz don't exist.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - zip_file = fixtures_path / "test.json.zip" + fixtures_path = tmp_path + zip_file = fixtures_path / "test.json.zip" - with zipfile.ZipFile(zip_file, "w") as zf: - zf.writestr("test.json", '{"test": "zip"}') + with zipfile.ZipFile(zip_file, "w") as zf: + zf.writestr("test.json", '{"test": "zip"}') - result = _find_fixture_file(fixtures_path, "test") - assert result == zip_file + result = _find_fixture_file(fixtures_path, "test") + assert result == zip_file -def test_find_fixture_file_not_found() -> None: +def test_find_fixture_file_not_found(tmp_path: Path) -> None: """Test FileNotFoundError when no fixture file exists.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) + fixtures_path = tmp_path - with pytest.raises(FileNotFoundError, match="Could not find the missing fixture"): - _find_fixture_file(fixtures_path, "missing") + with pytest.raises(FileNotFoundError, match="Could not find the missing fixture"): + _find_fixture_file(fixtures_path, "missing") -def test_read_gzip_file() -> None: +def test_read_gzip_file(tmp_path: Path) -> None: """Test reading gzipped JSON file.""" - with tempfile.TemporaryDirectory() as temp_dir: - gz_file = Path(temp_dir) / "test.json.gz" - test_data = {"test": "gzipped data", "number": 42} + gz_file = tmp_path / "test.json.gz" + test_data = {"test": "gzipped data", "number": 42} - with gzip.open(gz_file, "wt", encoding="utf-8") as f: - json.dump(test_data, f) + with gzip.open(gz_file, "wt", encoding="utf-8") as f: + json.dump(test_data, f) - result = _read_compressed_file(gz_file) - assert json.loads(result) == test_data + result = _read_compressed_file(gz_file) + assert json.loads(result) == test_data -def test_read_zip_file_with_matching_name() -> None: +def test_read_zip_file_with_matching_name(tmp_path: Path) -> None: """Test reading ZIP file with matching JSON filename.""" - with tempfile.TemporaryDirectory() as temp_dir: - zip_file = Path(temp_dir) / "test.json.zip" - test_data = {"test": "zipped data", "array": [1, 2, 3]} + zip_file = tmp_path / "test.json.zip" + test_data = {"test": "zipped data", "array": [1, 2, 3]} - with zipfile.ZipFile(zip_file, "w") as zf: - zf.writestr("test.json", json.dumps(test_data)) + with zipfile.ZipFile(zip_file, "w") as zf: + zf.writestr("test.json", json.dumps(test_data)) - result = _read_compressed_file(zip_file) - assert json.loads(result) == test_data + result = _read_compressed_file(zip_file) + assert json.loads(result) == test_data -def test_read_zip_file_first_json() -> None: +def test_read_zip_file_first_json(tmp_path: Path) -> None: """Test reading ZIP file with first 
JSON file when no matching name.""" - with tempfile.TemporaryDirectory() as temp_dir: - zip_file = Path(temp_dir) / "archive.zip" - test_data = {"test": "first json file"} + zip_file = tmp_path / "archive.zip" + test_data = {"test": "first json file"} - with zipfile.ZipFile(zip_file, "w") as zf: - zf.writestr("data.json", json.dumps(test_data)) - zf.writestr("other.txt", "not json") + with zipfile.ZipFile(zip_file, "w") as zf: + zf.writestr("data.json", json.dumps(test_data)) + zf.writestr("other.txt", "not json") - result = _read_compressed_file(zip_file) - assert json.loads(result) == test_data + result = _read_compressed_file(zip_file) + assert json.loads(result) == test_data -def test_read_zip_file_no_json() -> None: +def test_read_zip_file_no_json(tmp_path: Path) -> None: """Test error when ZIP file contains no JSON files.""" - with tempfile.TemporaryDirectory() as temp_dir: - zip_file = Path(temp_dir) / "empty.zip" + zip_file = tmp_path / "empty.zip" - with zipfile.ZipFile(zip_file, "w") as zf: - zf.writestr("data.txt", "not json") + with zipfile.ZipFile(zip_file, "w") as zf: + zf.writestr("data.txt", "not json") - with pytest.raises(ValueError, match="No JSON file found in ZIP archive"): - _read_compressed_file(zip_file) + with pytest.raises(ValueError, match="No JSON file found in ZIP archive"): + _read_compressed_file(zip_file) -def test_read_unsupported_format() -> None: +def test_read_unsupported_format(tmp_path: Path) -> None: """Test error for unsupported compression format.""" - with tempfile.TemporaryDirectory() as temp_dir: - unsupported_file = Path(temp_dir) / "test.tar.gz" - unsupported_file.write_text("data") + unsupported_file = tmp_path / "test.tar.gz" + unsupported_file.write_text("data") - # gzip module attempts to read .tar.gz files and raises BadGzipFile - with pytest.raises(gzip.BadGzipFile): - _read_compressed_file(unsupported_file) + # gzip module attempts to read .tar.gz files and raises BadGzipFile + with pytest.raises(gzip.BadGzipFile): + _read_compressed_file(unsupported_file) def test_serialize_dict() -> None: @@ -233,170 +223,157 @@ def test_serialize_primitive_none() -> None: assert json.loads(result) == data -def test_open_fixture_valid_file() -> None: +def test_open_fixture_valid_file(tmp_path: Path) -> None: """Test open_fixture with valid JSON fixture file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "test_fixture.json" + fixtures_path = tmp_path + fixture_file = fixtures_path / "test_fixture.json" - test_data = {"name": "test", "value": 42, "items": [1, 2, 3]} - with fixture_file.open("w") as f: - json.dump(test_data, f) + test_data = {"name": "test", "value": 42, "items": [1, 2, 3]} + with fixture_file.open("w") as f: + json.dump(test_data, f) - result = open_fixture(fixtures_path, "test_fixture") - assert result == test_data + result = open_fixture(fixtures_path, "test_fixture") + assert result == test_data -def test_open_fixture_gzipped() -> None: +def test_open_fixture_gzipped(tmp_path: Path) -> None: """Test open_fixture with gzipped JSON file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "test.json.gz" + fixtures_path = tmp_path + fixture_file = fixtures_path / "test.json.gz" - test_data = {"compressed": True, "data": [1, 2, 3]} - with gzip.open(fixture_file, "wt", encoding="utf-8") as f: - json.dump(test_data, f) + test_data = {"compressed": True, "data": [1, 2, 3]} + with gzip.open(fixture_file, 
"wt", encoding="utf-8") as f: + json.dump(test_data, f) - result = open_fixture(fixtures_path, "test") - assert result == test_data + result = open_fixture(fixtures_path, "test") + assert result == test_data -def test_open_fixture_zipped() -> None: +def test_open_fixture_zipped(tmp_path: Path) -> None: """Test open_fixture with zipped JSON file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "test.json.zip" + fixtures_path = tmp_path + fixture_file = fixtures_path / "test.json.zip" - test_data = {"zipped": True, "values": ["a", "b", "c"]} - with zipfile.ZipFile(fixture_file, "w") as zf: - zf.writestr("test.json", json.dumps(test_data)) + test_data = {"zipped": True, "values": ["a", "b", "c"]} + with zipfile.ZipFile(fixture_file, "w") as zf: + zf.writestr("test.json", json.dumps(test_data)) - result = open_fixture(fixtures_path, "test") - assert result == test_data + result = open_fixture(fixtures_path, "test") + assert result == test_data -def test_open_fixture_missing_file() -> None: +def test_open_fixture_missing_file(tmp_path: Path) -> None: """Test open_fixture with missing fixture file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) + fixtures_path = tmp_path - with pytest.raises(FileNotFoundError, match="Could not find the nonexistent fixture"): - open_fixture(fixtures_path, "nonexistent") + with pytest.raises(FileNotFoundError, match="Could not find the nonexistent fixture"): + open_fixture(fixtures_path, "nonexistent") -def test_open_fixture_invalid_json() -> None: +def test_open_fixture_invalid_json(tmp_path: Path) -> None: """Test open_fixture with invalid JSON.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "invalid.json" + fixtures_path = tmp_path + fixture_file = fixtures_path / "invalid.json" - with fixture_file.open("w") as f: - f.write("{ invalid json content") + with fixture_file.open("w") as f: + f.write("{ invalid json content") - with pytest.raises(Exception): - open_fixture(fixtures_path, "invalid") + with pytest.raises(Exception): + open_fixture(fixtures_path, "invalid") -async def test_open_fixture_async_valid_file() -> None: +async def test_open_fixture_async_valid_file(tmp_path: Path) -> None: """Test open_fixture_async with valid JSON fixture file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "test_async.json" + fixtures_path = tmp_path + fixture_file = fixtures_path / "test_async.json" - test_data = {"async": True, "data": {"nested": "value"}} - with fixture_file.open("w") as f: - json.dump(test_data, f) + test_data = {"async": True, "data": {"nested": "value"}} + with fixture_file.open("w") as f: + json.dump(test_data, f) - result = await open_fixture_async(fixtures_path, "test_async") - assert result == test_data + result = await open_fixture_async(fixtures_path, "test_async") + assert result == test_data -async def test_open_fixture_async_gzipped() -> None: +async def test_open_fixture_async_gzipped(tmp_path: Path) -> None: """Test open_fixture_async with gzipped file.""" - with tempfile.TemporaryDirectory() as temp_dir: - fixtures_path = Path(temp_dir) - fixture_file = fixtures_path / "async_gz.json.gz" + fixtures_path = tmp_path + fixture_file = fixtures_path / "async_gz.json.gz" - test_data = {"async_compressed": True, "numbers": [1, 2, 3, 4, 5]} - with gzip.open(fixture_file, "wt", encoding="utf-8") as 


-async def test_open_fixture_async_valid_file() -> None:
+async def test_open_fixture_async_valid_file(tmp_path: Path) -> None:
     """Test open_fixture_async with valid JSON fixture file."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = Path(temp_dir)
-        fixture_file = fixtures_path / "test_async.json"
+    fixtures_path = tmp_path
+    fixture_file = fixtures_path / "test_async.json"

-        test_data = {"async": True, "data": {"nested": "value"}}
-        with fixture_file.open("w") as f:
-            json.dump(test_data, f)
+    test_data = {"async": True, "data": {"nested": "value"}}
+    with fixture_file.open("w") as f:
+        json.dump(test_data, f)

-        result = await open_fixture_async(fixtures_path, "test_async")
-        assert result == test_data
+    result = await open_fixture_async(fixtures_path, "test_async")
+    assert result == test_data


-async def test_open_fixture_async_gzipped() -> None:
+async def test_open_fixture_async_gzipped(tmp_path: Path) -> None:
     """Test open_fixture_async with gzipped file."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = Path(temp_dir)
-        fixture_file = fixtures_path / "async_gz.json.gz"
+    fixtures_path = tmp_path
+    fixture_file = fixtures_path / "async_gz.json.gz"

-        test_data = {"async_compressed": True, "numbers": [1, 2, 3, 4, 5]}
-        with gzip.open(fixture_file, "wt", encoding="utf-8") as f:
-            json.dump(test_data, f)
+    test_data = {"async_compressed": True, "numbers": [1, 2, 3, 4, 5]}
+    with gzip.open(fixture_file, "wt", encoding="utf-8") as f:
+        json.dump(test_data, f)

-        result = await open_fixture_async(fixtures_path, "async_gz")
-        assert result == test_data
+    result = await open_fixture_async(fixtures_path, "async_gz")
+    assert result == test_data


-async def test_open_fixture_async_zipped() -> None:
+async def test_open_fixture_async_zipped(tmp_path: Path) -> None:
     """Test open_fixture_async with zipped file."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = Path(temp_dir)
-        fixture_file = fixtures_path / "async_zip.json.zip"
+    fixtures_path = tmp_path
+    fixture_file = fixtures_path / "async_zip.json.zip"

-        test_data = {"async_zipped": True, "items": ["x", "y", "z"]}
-        with zipfile.ZipFile(fixture_file, "w") as zf:
-            zf.writestr("async_zip.json", json.dumps(test_data))
+    test_data = {"async_zipped": True, "items": ["x", "y", "z"]}
+    with zipfile.ZipFile(fixture_file, "w") as zf:
+        zf.writestr("async_zip.json", json.dumps(test_data))

-        result = await open_fixture_async(fixtures_path, "async_zip")
-        assert result == test_data
+    result = await open_fixture_async(fixtures_path, "async_zip")
+    assert result == test_data


-async def test_open_fixture_async_missing_file() -> None:
+async def test_open_fixture_async_missing_file(tmp_path: Path) -> None:
     """Test open_fixture_async with missing fixture file."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        fixtures_path = Path(temp_dir)
+    fixtures_path = tmp_path

-        with pytest.raises(FileNotFoundError, match="Could not find the missing_async fixture"):
-            await open_fixture_async(fixtures_path, "missing_async")
+    with pytest.raises(FileNotFoundError, match="Could not find the missing_async fixture"):
+        await open_fixture_async(fixtures_path, "missing_async")
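
The async tests above are bare async def functions with no per-test marker, which presumes this suite's pytest configuration collects them automatically (for example pytest-asyncio in auto mode, or an anyio setup); that is an assumption about the repo's config, not something these hunks show. Under that assumption the async helpers mirror the sync API one-for-one; a hypothetical example test, not part of the diff:

from pathlib import Path

from sqlspec.utils.fixtures import open_fixture_async, write_fixture_async


async def test_async_mirror_example(tmp_path: Path) -> None:
    # Same str-directory convention on write, same Path acceptance on read
    # as the real tests in this file.
    await write_fixture_async(str(tmp_path), "mirror", {"ok": True})
    assert await open_fixture_async(tmp_path, "mirror") == {"ok": True}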
{"compressed": True, "data": list(range(100))} + test_data: Any = {"compressed": True, "data": list(range(100))} - write_fixture(temp_dir, "test_compressed", test_data, compress=True) + write_fixture(str(tmp_path), "test_compressed", test_data, compress=True) - # Verify gzipped file was created - fixture_file = Path(temp_dir) / "test_compressed.json.gz" - assert fixture_file.exists() + # Verify gzipped file was created + fixture_file = tmp_path / "test_compressed.json.gz" + assert fixture_file.exists() - # Verify content can be read - loaded_data = open_fixture(temp_dir, "test_compressed") - assert loaded_data == test_data + # Verify content can be read + loaded_data = open_fixture(tmp_path, "test_compressed") + assert loaded_data == test_data -def test_write_fixture_storage_backend_error() -> None: +def test_write_fixture_storage_backend_error(tmp_path: Path) -> None: """Test error handling for invalid storage backend.""" - with tempfile.TemporaryDirectory() as temp_dir: - test_data: Any = {"test": "data"} + test_data: Any = {"test": "data"} - with pytest.raises(ValueError, match="Failed to get storage backend"): - write_fixture(temp_dir, "test", test_data, storage_backend="invalid://backend") + with pytest.raises(ValueError, match="Failed to get storage backend"): + write_fixture(str(tmp_path), "test", test_data, storage_backend="invalid://backend") @patch("sqlspec.utils.fixtures.storage_registry") @@ -413,41 +390,38 @@ def test_write_fixture_with_custom_backend(mock_registry: Mock) -> None: mock_storage.write_text.assert_called_once() -async def test_write_fixture_async_dict() -> None: +async def test_write_fixture_async_dict(tmp_path: Path) -> None: """Test async writing a dictionary fixture.""" - with tempfile.TemporaryDirectory() as temp_dir: - test_data: Any = {"async_write": True, "value": 123} + test_data: Any = {"async_write": True, "value": 123} - await write_fixture_async(temp_dir, "async_test", test_data) + await write_fixture_async(str(tmp_path), "async_test", test_data) - # Verify file was created and content is correct - loaded_data = await open_fixture_async(temp_dir, "async_test") - assert loaded_data == test_data + # Verify file was created and content is correct + loaded_data = await open_fixture_async(tmp_path, "async_test") + assert loaded_data == test_data -async def test_write_fixture_async_compressed() -> None: +async def test_write_fixture_async_compressed(tmp_path: Path) -> None: """Test async writing a compressed fixture.""" - with tempfile.TemporaryDirectory() as temp_dir: - test_data: Any = {"async_compressed": True, "large_data": list(range(50))} + test_data: Any = {"async_compressed": True, "large_data": list(range(50))} - await write_fixture_async(temp_dir, "async_compressed", test_data, compress=True) + await write_fixture_async(str(tmp_path), "async_compressed", test_data, compress=True) - # Verify gzipped file was created - fixture_file = Path(temp_dir) / "async_compressed.json.gz" - assert fixture_file.exists() + # Verify gzipped file was created + fixture_file = tmp_path / "async_compressed.json.gz" + assert fixture_file.exists() - # Verify content - loaded_data = await open_fixture_async(temp_dir, "async_compressed") - assert loaded_data == test_data + # Verify content + loaded_data = await open_fixture_async(tmp_path, "async_compressed") + assert loaded_data == test_data -async def test_write_fixture_async_storage_error() -> None: +async def test_write_fixture_async_storage_error(tmp_path: Path) -> None: """Test async error handling for invalid 
storage backend.""" - with tempfile.TemporaryDirectory() as temp_dir: - test_data: Any = {"test": "data"} + test_data: Any = {"test": "data"} - with pytest.raises(ValueError, match="Failed to get storage backend"): - await write_fixture_async(temp_dir, "test", test_data, storage_backend="invalid://backend") + with pytest.raises(ValueError, match="Failed to get storage backend"): + await write_fixture_async(str(tmp_path), "test", test_data, storage_backend="invalid://backend") @patch("sqlspec.utils.fixtures.storage_registry") @@ -466,62 +440,59 @@ async def test_write_fixture_async_custom_backend(mock_registry: Mock) -> None: mock_storage.write_text_async.assert_called_once() -def test_write_read_roundtrip() -> None: +def test_write_read_roundtrip(tmp_path: Path) -> None: """Test complete write and read roundtrip.""" - with tempfile.TemporaryDirectory() as temp_dir: - original_data: Any = { - "users": [{"id": 1, "name": "Alice", "active": True}, {"id": 2, "name": "Bob", "active": False}], - "metadata": {"version": "1.0", "created": "2024-01-01", "total_users": 2}, - } + original_data: Any = { + "users": [{"id": 1, "name": "Alice", "active": True}, {"id": 2, "name": "Bob", "active": False}], + "metadata": {"version": "1.0", "created": "2024-01-01", "total_users": 2}, + } - # Write fixture - write_fixture(temp_dir, "integration_test", original_data) + # Write fixture + write_fixture(str(tmp_path), "integration_test", original_data) - # Read fixture back - loaded_data = open_fixture(temp_dir, "integration_test") + # Read fixture back + loaded_data = open_fixture(tmp_path, "integration_test") - # Verify data integrity - assert loaded_data == original_data + # Verify data integrity + assert loaded_data == original_data -async def test_async_write_read_roundtrip() -> None: +async def test_async_write_read_roundtrip(tmp_path: Path) -> None: """Test complete async write and read roundtrip.""" - with tempfile.TemporaryDirectory() as temp_dir: - original_data: Any = { - "async_test": True, - "data": {"nested": {"deeply": {"value": 42}}}, - "list_data": [{"item": i} for i in range(10)], - } + original_data: Any = { + "async_test": True, + "data": {"nested": {"deeply": {"value": 42}}}, + "list_data": [{"item": i} for i in range(10)], + } - # Write fixture async - await write_fixture_async(temp_dir, "async_integration", original_data) + # Write fixture async + await write_fixture_async(str(tmp_path), "async_integration", original_data) - # Read fixture back async - loaded_data = await open_fixture_async(temp_dir, "async_integration") + # Read fixture back async + loaded_data = await open_fixture_async(tmp_path, "async_integration") - # Verify data integrity - assert loaded_data == original_data + # Verify data integrity + assert loaded_data == original_data -def test_compressed_roundtrip() -> None: +def test_compressed_roundtrip(tmp_path: Path) -> None: """Test write and read roundtrip with compression.""" - with tempfile.TemporaryDirectory() as temp_dir: - # Large data that benefits from compression - original_data: Any = { - "large_list": [{"id": i, "data": f"item_{i}" * 10} for i in range(100)], - "repeated_data": ["same_string"] * 50, - } - - # Write compressed - write_fixture(temp_dir, "compressed_test", original_data, compress=True) - - # Read back - loaded_data = open_fixture(temp_dir, "compressed_test") - - # Verify data integrity - assert loaded_data == original_data - - # Verify file is actually compressed - compressed_file = Path(temp_dir) / "compressed_test.json.gz" - assert 


-def test_compressed_roundtrip() -> None:
+def test_compressed_roundtrip(tmp_path: Path) -> None:
     """Test write and read roundtrip with compression."""
-    with tempfile.TemporaryDirectory() as temp_dir:
-        # Large data that benefits from compression
-        original_data: Any = {
-            "large_list": [{"id": i, "data": f"item_{i}" * 10} for i in range(100)],
-            "repeated_data": ["same_string"] * 50,
-        }
-
-        # Write compressed
-        write_fixture(temp_dir, "compressed_test", original_data, compress=True)
-
-        # Read back
-        loaded_data = open_fixture(temp_dir, "compressed_test")
-
-        # Verify data integrity
-        assert loaded_data == original_data
-
-        # Verify file is actually compressed
-        compressed_file = Path(temp_dir) / "compressed_test.json.gz"
-        assert compressed_file.exists()
-        assert compressed_file.suffix == ".gz"
+    # Large data that benefits from compression
+    original_data: Any = {
+        "large_list": [{"id": i, "data": f"item_{i}" * 10} for i in range(100)],
+        "repeated_data": ["same_string"] * 50,
+    }
+
+    # Write compressed
+    write_fixture(str(tmp_path), "compressed_test", original_data, compress=True)
+
+    # Read back
+    loaded_data = open_fixture(tmp_path, "compressed_test")
+
+    # Verify data integrity
+    assert loaded_data == original_data
+
+    # Verify file is actually compressed
+    compressed_file = tmp_path / "compressed_test.json.gz"
+    assert compressed_file.exists()
+    assert compressed_file.suffix == ".gz"
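
The mechanical pattern behind every hunk in this file: pytest's built-in tmp_path fixture injects a unique, pre-created pathlib.Path for each test and cleans it up afterwards, so the tempfile import, the TemporaryDirectory() context manager, the Path(temp_dir) conversions, and one level of indentation all disappear. The transformation in miniature, with hypothetical test names:

import tempfile
from pathlib import Path


def test_before() -> None:
    # Old style: manual lifetime management plus an extra indent level.
    with tempfile.TemporaryDirectory() as temp_dir:
        target = Path(temp_dir) / "example.txt"
        target.write_text("data")


def test_after(tmp_path: Path) -> None:
    # New style: pytest owns the directory's lifetime; the body flattens.
    target = tmp_path / "example.txt"
    target.write_text("data")

A side benefit: pytest retains the last few temporary base directories between runs, so a failing test's files stay available for post-mortem inspection, unlike a TemporaryDirectory that deletes itself on exit.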