diff --git a/Makefile b/Makefile index 5101694d0..139399f77 100644 --- a/Makefile +++ b/Makefile @@ -118,7 +118,7 @@ clean: ## Cleanup temporary build a .PHONY: test test: ## Run the tests @echo "${INFO} Running test cases... 🧪" - @uv run pytest tests + @uv run pytest -n 2 --dist=loadgroup tests @echo "${OK} Tests complete ✨" .PHONY: test-all @@ -128,7 +128,7 @@ test-all: tests ## Run all tests .PHONY: coverage coverage: ## Run tests with coverage report @echo "${INFO} Running tests with coverage... 📊" - @uv run pytest --cov -n auto --quiet + @uv run pytest --cov -n 2 --dist=loadgroup --quiet @uv run coverage html >/dev/null 2>&1 @uv run coverage xml >/dev/null 2>&1 @echo "${OK} Coverage report generated ✨" diff --git a/pyproject.toml b/pyproject.toml index abf1a7eaf..40721e033 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,7 @@ exclude_lines = [ ] [tool.pytest.ini_options] -addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config" +addopts = ["-q", "-ra"] asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" filterwarnings = [ @@ -189,8 +189,31 @@ filterwarnings = [ "ignore::DeprecationWarning:websockets.connection", "ignore::DeprecationWarning:websockets.legacy", ] +markers = [ + "integration: marks tests that require an external database", + "postgres: marks tests specific to PostgreSQL", + "duckdb: marks tests specific to DuckDB", + "sqlite: marks tests specific to SQLite", + "bigquery: marks tests specific to Google BigQuery", + "mysql: marks tests specific to MySQL", + "oracle: marks tests specific to Oracle", + "spanner: marks tests specific to Google Cloud Spanner", + "mssql: marks tests specific to Microsoft SQL Server", + # Driver markers + "adbc: marks tests using ADBC drivers", + "aioodbc: marks tests using aioodbc", + "aiosqlite: marks tests using aiosqlite", + "asyncmy: marks tests using asyncmy", + "asyncpg: marks tests using asyncpg", + "duckdb_driver: marks tests using the duckdb driver", + "google_bigquery: marks tests using google-cloud-bigquery", + "google_spanner: marks tests using google-cloud-spanner", + "oracledb: marks tests using oracledb", + "psycopg: marks tests using psycopg", + "pymssql: marks tests using pymssql", + "pymysql: marks tests using pymysql", +] testpaths = ["tests"] -xfail_strict = true [tool.mypy] packages = ["sqlspec", "tests"] @@ -220,6 +243,8 @@ module = [ "uvloop.*", "asyncmy", "asyncmy.*", + "pyarrow", + "pyarrow.*", ] [tool.pyright] diff --git a/sqlspec/_typing.py b/sqlspec/_typing.py index 6f9357b7a..d1ed14c83 100644 --- a/sqlspec/_typing.py +++ b/sqlspec/_typing.py @@ -1,8 +1,10 @@ +# ruff: noqa: RUF100, PLR0913, A002, DOC201, PLR6301 """This is a simple wrapper around a few important classes in each library. This is used to ensure compatibility when one or more of the libraries are installed. """ +from collections.abc import Iterable, Mapping from enum import Enum from typing import ( Any, @@ -96,7 +98,7 @@ def __init__( def validate_python( self, - object: Any, # noqa: A002 + object: Any, /, *, strict: "Optional[bool]" = None, @@ -127,10 +129,7 @@ class FailFast: # type: ignore[no-redef] except ImportError: import enum from collections.abc import Iterable - from typing import TYPE_CHECKING, Callable, Optional, Union - - if TYPE_CHECKING: - from collections.abc import Iterable + from typing import Callable, Optional, Union @dataclass_transform() @runtime_checkable @@ -174,7 +173,6 @@ def __init__(self, backend: Any, data_as_builtins: Any) -> None: """Placeholder init""" def create_instance(self, **kwargs: Any) -> "T": - """Placeholder implementation""" return cast("T", kwargs) def update_instance(self, instance: "T", **kwargs: Any) -> "T": @@ -198,11 +196,46 @@ class EmptyEnum(Enum): Empty: Final = EmptyEnum.EMPTY +try: + from pyarrow import Table as ArrowTable + + PYARROW_INSTALLED = True +except ImportError: + + @runtime_checkable + class ArrowTable(Protocol): # type: ignore[no-redef] + """Placeholder Implementation""" + + def to_batches(self, batch_size: int) -> Any: ... + def num_rows(self) -> int: ... + def num_columns(self) -> int: ... + def to_pydict(self) -> dict[str, Any]: ... + def to_string(self) -> str: ... + def from_arrays( + self, + arrays: list[Any], + names: "Optional[list[str]]" = None, + schema: "Optional[Any]" = None, + metadata: "Optional[Mapping[str, Any]]" = None, + ) -> Any: ... + def from_pydict( + self, + mapping: dict[str, Any], + schema: "Optional[Any]" = None, + metadata: "Optional[Mapping[str, Any]]" = None, + ) -> Any: ... + def from_batches(self, batches: Iterable[Any], schema: Optional[Any] = None) -> Any: ... + + PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition] + + __all__ = ( "LITESTAR_INSTALLED", "MSGSPEC_INSTALLED", + "PYARROW_INSTALLED", "PYDANTIC_INSTALLED", "UNSET", + "ArrowTable", "BaseModel", "DTOData", "DataclassProtocol", diff --git a/sqlspec/adapters/adbc/__init__.py b/sqlspec/adapters/adbc/__init__.py index 75c5b4718..f7432e597 100644 --- a/sqlspec/adapters/adbc/__init__.py +++ b/sqlspec/adapters/adbc/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.adbc.config import Adbc +from sqlspec.adapters.adbc.config import AdbcConfig from sqlspec.adapters.adbc.driver import AdbcDriver __all__ = ( - "Adbc", + "AdbcConfig", "AdbcDriver", ) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 7758741ac..86fcb738b 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -14,11 +14,11 @@ from collections.abc import Generator -__all__ = ("Adbc",) +__all__ = ("AdbcConfig",) @dataclass -class Adbc(NoPoolSyncConfig["Connection", "AdbcDriver"]): +class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]): """Configuration for ADBC connections. This class provides configuration options for ADBC database connections using the @@ -55,17 +55,41 @@ def _set_adbc(self) -> str: # noqa: PLR0912 """ if isinstance(self.driver_name, str): - if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and "sqlite" in self.driver_name: + if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and self.driver_name in { + "sqlite", + "sqlite3", + "adbc_driver_sqlite", + }: self.driver_name = "adbc_driver_sqlite.dbapi.connect" - elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and "duckdb" in self.driver_name: + elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and self.driver_name in { + "duckdb", + "adbc_driver_duckdb", + }: self.driver_name = "adbc_driver_duckdb.dbapi.connect" - elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and "postgres" in self.driver_name: + elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and self.driver_name in { + "postgres", + "adbc_driver_postgresql", + "postgresql", + "pg", + }: self.driver_name = "adbc_driver_postgresql.dbapi.connect" - elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and "snowflake" in self.driver_name: + elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and self.driver_name in { + "snowflake", + "adbc_driver_snowflake", + "sf", + }: self.driver_name = "adbc_driver_snowflake.dbapi.connect" - elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and "bigquery" in self.driver_name: + elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and self.driver_name in { + "bigquery", + "adbc_driver_bigquery", + "bq", + }: self.driver_name = "adbc_driver_bigquery.dbapi.connect" - elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and "flightsql" in self.driver_name: + elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and self.driver_name in { + "flightsql", + "adbc_driver_flightsql", + "grpc", + }: self.driver_name = "adbc_driver_flightsql.dbapi.connect" return self.driver_name @@ -153,11 +177,10 @@ def create_connection(self) -> "Connection": """ try: connect_func = self._get_connect_func() - _config = self.connection_config_dict - return connect_func(**_config) + return connect_func(**self.connection_config_dict) except Exception as e: # Include driver name in error message for better context - driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Derived" + driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Missing" # Use the potentially modified driver_path from _get_connect_func if available, # otherwise fallback to self.driver_name for the error message. # This requires _get_connect_func to potentially return the used path or store it. diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index f0595fd79..309b93baf 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -2,14 +2,16 @@ import re from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast -from adbc_driver_manager.dbapi import Connection, Cursor +from adbc_driver_manager.dbapi import Connection +from adbc_driver_manager.dbapi import Cursor as DbapiCursor -from sqlspec.base import SyncDriverAdapterProtocol, T +from sqlspec._typing import ArrowTable +from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T if TYPE_CHECKING: - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType __all__ = ("AdbcDriver",) @@ -26,10 +28,11 @@ ) -class AdbcDriver(SyncDriverAdapterProtocol["Connection"]): +class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]): """ADBC Sync Driver Adapter.""" connection: Connection + __supports_arrow__: ClassVar[bool] = True def __init__(self, connection: "Connection") -> None: """Initialize the ADBC driver adapter.""" @@ -38,12 +41,12 @@ def __init__(self, connection: "Connection") -> None: # For now, assume 'qmark' based on typical ADBC DBAPI behavior @staticmethod - def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor": + def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "DbapiCursor": return connection.cursor(*args, **kwargs) @contextmanager - def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]: - cursor = self._cursor(connection) + def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", None, None]: + cursor: DbapiCursor = self._cursor(connection) try: yield cursor finally: @@ -331,3 +334,24 @@ def execute_script_returning( if schema_type is not None: return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + + # --- Arrow Bulk Operations --- + + def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[Connection]" = None, + ) -> "ArrowTable": + """Execute a SQL query and return results as an Apache Arrow Table. + + Returns: + The results of the query as an Apache Arrow Table. + """ + conn = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters) + + with self._with_cursor(conn) as cursor: + cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] + return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/aiosqlite/__init__.py b/sqlspec/adapters/aiosqlite/__init__.py index cdefbea26..18c21a803 100644 --- a/sqlspec/adapters/aiosqlite/__init__.py +++ b/sqlspec/adapters/aiosqlite/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.aiosqlite.config import Aiosqlite +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver __all__ = ( - "Aiosqlite", + "AiosqliteConfig", "AiosqliteDriver", ) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 6ae47c2d0..87512c768 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -15,11 +15,11 @@ from typing import Literal -__all__ = ("Aiosqlite",) +__all__ = ("AiosqliteConfig",) @dataclass -class Aiosqlite(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]): +class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]): """Configuration for Aiosqlite database connections. This class provides configuration options for Aiosqlite database connections, wrapping all parameters diff --git a/sqlspec/adapters/asyncmy/__init__.py b/sqlspec/adapters/asyncmy/__init__.py index 00d4aa39d..6774c059d 100644 --- a/sqlspec/adapters/asyncmy/__init__.py +++ b/sqlspec/adapters/asyncmy/__init__.py @@ -1,8 +1,8 @@ -from sqlspec.adapters.asyncmy.config import Asyncmy, AsyncmyPool +from sqlspec.adapters.asyncmy.config import AsyncmyConfig, AsyncmyPoolConfig from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined] __all__ = ( - "Asyncmy", + "AsyncmyConfig", "AsyncmyDriver", - "AsyncmyPool", + "AsyncmyPoolConfig", ) diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index d51edf6af..94f5ec959 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -16,8 +16,8 @@ from asyncmy.pool import Pool # pyright: ignore[reportUnknownVariableType] __all__ = ( - "Asyncmy", - "AsyncmyPool", + "AsyncmyConfig", + "AsyncmyPoolConfig", ) @@ -25,7 +25,7 @@ @dataclass -class AsyncmyPool(GenericPoolConfig): +class AsyncmyPoolConfig(GenericPoolConfig): """Configuration for Asyncmy's connection pool. This class provides configuration options for Asyncmy database connection pools. @@ -104,19 +104,19 @@ def pool_config_dict(self) -> "dict[str, Any]": @dataclass -class Asyncmy(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]): +class AsyncmyConfig(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]): """Asyncmy Configuration.""" __is_async__ = True __supports_connection_pooling__ = True - pool_config: "Optional[AsyncmyPool]" = None + pool_config: "Optional[AsyncmyPoolConfig]" = None """Asyncmy Pool configuration""" - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore + connection_type: "type[Connection]" = field(hash=False, init=False, default_factory=lambda: Connection) # pyright: ignore """Type of the connection object""" - driver_type: "type[AsyncmyDriver]" = field(init=False, default_factory=lambda: AsyncmyDriver) + driver_type: "type[AsyncmyDriver]" = field(hash=False, init=False, default_factory=lambda: AsyncmyDriver) """Type of the driver object""" - pool_instance: "Optional[Pool]" = None # pyright: ignore[reportUnknownVariableType] + pool_instance: "Optional[Pool]" = field(hash=False, default=None) # pyright: ignore[reportUnknownVariableType] """Instance of the pool""" @property diff --git a/sqlspec/adapters/asyncpg/__init__.py b/sqlspec/adapters/asyncpg/__init__.py index d3f1d9f6a..a4ad19b2f 100644 --- a/sqlspec/adapters/asyncpg/__init__.py +++ b/sqlspec/adapters/asyncpg/__init__.py @@ -1,9 +1,9 @@ -from sqlspec.adapters.asyncpg.config import Asyncpg, AsyncpgPool, PgConnection -from sqlspec.adapters.asyncpg.driver import AsyncpgDriver +from sqlspec.adapters.asyncpg.config import AsyncpgConfig, AsyncpgPoolConfig +from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver __all__ = ( - "Asyncpg", + "AsyncpgConfig", + "AsyncpgConnection", "AsyncpgDriver", - "AsyncpgPool", - "PgConnection", + "AsyncpgPoolConfig", ) diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 5882db381..ad61811e3 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -5,10 +5,9 @@ from asyncpg import Record from asyncpg import create_pool as asyncpg_create_pool from asyncpg.pool import PoolConnectionProxy -from typing_extensions import TypeAlias from sqlspec._serialization import decode_json, encode_json -from sqlspec.adapters.asyncpg.driver import AsyncpgDriver +from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver from sqlspec.base import AsyncDatabaseConfig, GenericPoolConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType, dataclass_to_dict @@ -22,18 +21,16 @@ __all__ = ( - "Asyncpg", - "AsyncpgPool", + "AsyncpgConfig", + "AsyncpgPoolConfig", ) T = TypeVar("T") -PgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" - @dataclass -class AsyncpgPool(GenericPoolConfig): +class AsyncpgPoolConfig(GenericPoolConfig): """Configuration for Asyncpg's :class:`Pool `. For details see: https://magicstack.github.io/asyncpg/current/api/index.html#connection-pools @@ -73,23 +70,25 @@ class AsyncpgPool(GenericPoolConfig): @dataclass -class Asyncpg(AsyncDatabaseConfig["PgConnection", "Pool", "AsyncpgDriver"]): # pyright: ignore[reportMissingTypeArgument] +class AsyncpgConfig(AsyncDatabaseConfig["AsyncpgConnection", "Pool", "AsyncpgDriver"]): # pyright: ignore[reportMissingTypeArgument] """Asyncpg Configuration.""" - pool_config: "Optional[AsyncpgPool]" = None + pool_config: "Optional[AsyncpgPoolConfig]" = field(default=None) """Asyncpg Pool configuration""" - json_deserializer: "Callable[[str], Any]" = decode_json + json_deserializer: "Callable[[str], Any]" = field(hash=False, default=decode_json) """For dialects that support the :class:`JSON ` datatype, this is a Python callable that will convert a JSON string to a Python object. By default, this is set to SQLSpec's :attr:`decode_json() ` function.""" - json_serializer: "Callable[[Any], str]" = encode_json + json_serializer: "Callable[[Any], str]" = field(hash=False, default=encode_json) """For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON. By default, SQLSpec's :attr:`encode_json() ` is used.""" - connection_type: "type[PgConnection]" = field(init=False, default_factory=lambda: PoolConnectionProxy) + connection_type: "type[AsyncpgConnection]" = field( + hash=False, init=False, default_factory=lambda: PoolConnectionProxy + ) """Type of the connection object""" - driver_type: "type[AsyncpgDriver]" = field(init=False, default_factory=lambda: AsyncpgDriver) # type: ignore[type-abstract,unused-ignore] + driver_type: "type[AsyncpgDriver]" = field(hash=False, init=False, default_factory=lambda: AsyncpgDriver) # type: ignore[type-abstract,unused-ignore] """Type of the driver object""" - pool_instance: "Optional[Pool[Any]]" = None + pool_instance: "Optional[Pool[Any]]" = field(hash=False, default=None) """The connection pool instance. If set, this will be used instead of creating a new pool.""" @property @@ -174,7 +173,7 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[Pool]": # p """ return self.create_pool() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - async def create_connection(self) -> "PgConnection": + async def create_connection(self) -> "AsyncpgConnection": """Create and return a new asyncpg connection. Returns: diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 6a0039f66..48ad1f67b 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -11,18 +11,18 @@ from sqlspec.typing import ModelDTOT, StatementParameterType -__all__ = ("AsyncpgDriver",) +__all__ = ("AsyncpgConnection", "AsyncpgDriver") -PgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument] +AsyncpgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument] -class AsyncpgDriver(AsyncDriverAdapterProtocol["PgConnection"]): +class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]): """AsyncPG Postgres Driver Adapter.""" - connection: "PgConnection" + connection: "AsyncpgConnection" - def __init__(self, connection: "PgConnection") -> None: + def __init__(self, connection: "AsyncpgConnection") -> None: self.connection = connection def _process_sql_params( @@ -36,7 +36,7 @@ async def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -65,7 +65,7 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -80,11 +80,9 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - result = await connection.fetchrow(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] result = self.check_not_found(result) if schema_type is None: @@ -97,7 +95,7 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -126,7 +124,7 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[PgConnection]" = None, + connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -135,11 +133,9 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - result = await connection.fetchval(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] result = self.check_not_found(result) if schema_type is None: return result[0] @@ -150,7 +146,7 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[PgConnection]" = None, + connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -159,11 +155,9 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - result = await connection.fetchval(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if result is None: return None if schema_type is None: @@ -175,7 +169,7 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, ) -> int: """Insert, update, or delete data from the database. @@ -188,11 +182,9 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - status = await connection.execute(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + status = await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] # AsyncPG returns a string like "INSERT 0 1" where the last number is the affected rows try: return int(status.split()[-1]) # pyright: ignore[reportUnknownMemberType] @@ -204,7 +196,7 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -219,11 +211,9 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - result = await connection.fetchrow(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if result is None: return None if schema_type is None: @@ -236,7 +226,7 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, ) -> str: """Execute a script. @@ -249,18 +239,16 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - return await connection.execute(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] async def execute_script_returning( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, - connection: Optional["PgConnection"] = None, + connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Execute a script and return result. @@ -275,11 +263,9 @@ async def execute_script_returning( The first row of results. """ connection = self._connection(connection) - sql, params = self._process_sql_params(sql, parameters) - # Use empty tuple if params is None - params = params if params is not None else () + sql, parameters = self._process_sql_params(sql, parameters) - result = await connection.fetchrow(sql, *params) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if result is None: return None if schema_type is None: diff --git a/sqlspec/adapters/duckdb/__init__.py b/sqlspec/adapters/duckdb/__init__.py index 1c3e3f638..39dfdae57 100644 --- a/sqlspec/adapters/duckdb/__init__.py +++ b/sqlspec/adapters/duckdb/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.duckdb.config import DuckDB +from sqlspec.adapters.duckdb.config import DuckDBConfig from sqlspec.adapters.duckdb.driver import DuckDBDriver __all__ = ( - "DuckDB", + "DuckDBConfig", "DuckDBDriver", ) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index df7400870..68270901b 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -14,7 +14,7 @@ from collections.abc import Generator, Sequence -__all__ = ("DuckDB", "ExtensionConfig") +__all__ = ("DuckDBConfig", "ExtensionConfig") class ExtensionConfig(TypedDict): @@ -69,7 +69,7 @@ class SecretConfig(TypedDict): @dataclass -class DuckDB(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]): +class DuckDBConfig(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]): """Configuration for DuckDB database connections. This class provides configuration options for DuckDB database connections, wrapping all parameters diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index db2f6a78d..6b01453c9 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,24 +1,24 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast -from sqlspec.base import SyncDriverAdapterProtocol, T +from sqlspec._typing import ArrowTable +from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T if TYPE_CHECKING: from collections.abc import Generator from duckdb import DuckDBPyConnection - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType __all__ = ("DuckDBDriver",) -class DuckDBDriver(SyncDriverAdapterProtocol["DuckDBPyConnection"]): +class DuckDBDriver(SyncArrowBulkOperationsMixin["DuckDBPyConnection"], SyncDriverAdapterProtocol["DuckDBPyConnection"]): """DuckDB Sync Driver Adapter.""" connection: "DuckDBPyConnection" use_cursor: bool = True - # param_style is inherited from CommonDriverAttributes def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> None: self.connection = connection @@ -80,7 +80,7 @@ def select_one( with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - result = self.check_not_found(result) # pyright: ignore + result = self.check_not_found(result) # pyright: ignore column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] if schema_type is not None: @@ -124,7 +124,7 @@ def select_value( with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - result = self.check_not_found(result) # pyright: ignore + result = self.check_not_found(result) # pyright: ignore if schema_type is None: return result[0] # pyright: ignore return schema_type(result[0]) # type: ignore[call-arg] @@ -223,3 +223,25 @@ def execute_script( with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("str", getattr(cursor, "statusmessage", "DONE")) # pyright: ignore[reportUnknownMemberType] + + # --- Arrow Bulk Operations --- + + def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[DuckDBPyConnection]" = None, + ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] + """Execute a SQL query and return results as an Apache Arrow Table. + + Returns: + An Apache Arrow Table containing the query results. + """ + + conn = self._connection(connection) + processed_sql, processed_params = self._process_sql_params(sql, parameters) + + with self._with_cursor(conn) as cursor: + cursor.execute(processed_sql, processed_params) # pyright: ignore[reportUnknownMemberType] + return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/oracledb/__init__.py b/sqlspec/adapters/oracledb/__init__.py index 69d8f08c9..4a6af3d74 100644 --- a/sqlspec/adapters/oracledb/__init__.py +++ b/sqlspec/adapters/oracledb/__init__.py @@ -1,16 +1,16 @@ from sqlspec.adapters.oracledb.config import ( - OracleAsync, - OracleAsyncPool, - OracleSync, - OracleSyncPool, + OracleAsyncConfig, + OracleAsyncPoolConfig, + OracleSyncConfig, + OracleSyncPoolConfig, ) from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver __all__ = ( - "OracleAsync", + "OracleAsyncConfig", "OracleAsyncDriver", - "OracleAsyncPool", - "OracleSync", + "OracleAsyncPoolConfig", + "OracleSyncConfig", "OracleSyncDriver", - "OracleSyncPool", + "OracleSyncPoolConfig", ) diff --git a/sqlspec/adapters/oracledb/config/__init__.py b/sqlspec/adapters/oracledb/config/__init__.py index d4f400a00..e7d3c66ba 100644 --- a/sqlspec/adapters/oracledb/config/__init__.py +++ b/sqlspec/adapters/oracledb/config/__init__.py @@ -1,9 +1,9 @@ -from sqlspec.adapters.oracledb.config._asyncio import OracleAsync, OracleAsyncPool -from sqlspec.adapters.oracledb.config._sync import OracleSync, OracleSyncPool +from sqlspec.adapters.oracledb.config._asyncio import OracleAsyncConfig, OracleAsyncPoolConfig +from sqlspec.adapters.oracledb.config._sync import OracleSyncConfig, OracleSyncPoolConfig __all__ = ( - "OracleAsync", - "OracleAsyncPool", - "OracleSync", - "OracleSyncPool", + "OracleAsyncConfig", + "OracleAsyncPoolConfig", + "OracleSyncConfig", + "OracleSyncPoolConfig", ) diff --git a/sqlspec/adapters/oracledb/config/_asyncio.py b/sqlspec/adapters/oracledb/config/_asyncio.py index 681a9948c..4ee03c94a 100644 --- a/sqlspec/adapters/oracledb/config/_asyncio.py +++ b/sqlspec/adapters/oracledb/config/_asyncio.py @@ -18,18 +18,18 @@ __all__ = ( - "OracleAsync", - "OracleAsyncPool", + "OracleAsyncConfig", + "OracleAsyncPoolConfig", ) @dataclass -class OracleAsyncPool(OracleGenericPoolConfig["AsyncConnection", "AsyncConnectionPool"]): +class OracleAsyncPoolConfig(OracleGenericPoolConfig["AsyncConnection", "AsyncConnectionPool"]): """Async Oracle Pool Config""" @dataclass -class OracleAsync(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionPool", "OracleAsyncDriver"]): +class OracleAsyncConfig(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionPool", "OracleAsyncDriver"]): """Oracle Async database Configuration. This class provides the base configuration for Oracle database connections, extending @@ -42,7 +42,7 @@ class OracleAsync(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionPool", options.([2](https://python-oracledb.readthedocs.io/en/latest/user_guide/tuning.html)) """ - pool_config: "Optional[OracleAsyncPool]" = None + pool_config: "Optional[OracleAsyncPoolConfig]" = None """Oracle Pool configuration""" pool_instance: "Optional[AsyncConnectionPool]" = None """Optional pool to use. diff --git a/sqlspec/adapters/oracledb/config/_sync.py b/sqlspec/adapters/oracledb/config/_sync.py index e532225f3..3e85b384d 100644 --- a/sqlspec/adapters/oracledb/config/_sync.py +++ b/sqlspec/adapters/oracledb/config/_sync.py @@ -18,18 +18,18 @@ __all__ = ( - "OracleSync", - "OracleSyncPool", + "OracleSyncConfig", + "OracleSyncPoolConfig", ) @dataclass -class OracleSyncPool(OracleGenericPoolConfig["Connection", "ConnectionPool"]): +class OracleSyncPoolConfig(OracleGenericPoolConfig["Connection", "ConnectionPool"]): """Sync Oracle Pool Config""" @dataclass -class OracleSync(SyncDatabaseConfig["Connection", "ConnectionPool", "OracleSyncDriver"]): +class OracleSyncConfig(SyncDatabaseConfig["Connection", "ConnectionPool", "OracleSyncDriver"]): """Oracle Sync database Configuration. This class provides the base configuration for Oracle database connections, extending @@ -42,7 +42,7 @@ class OracleSync(SyncDatabaseConfig["Connection", "ConnectionPool", "OracleSyncD options.([2](https://python-oracledb.readthedocs.io/en/latest/user_guide/tuning.html)) """ - pool_config: "Optional[OracleSyncPool]" = None + pool_config: "Optional[OracleSyncPoolConfig]" = None """Oracle Pool configuration""" pool_instance: "Optional[ConnectionPool]" = None """Optional pool to use. diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index c4db3de7f..910c62e9f 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1,19 +1,27 @@ from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast -from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.base import ( + AsyncArrowBulkOperationsMixin, + AsyncDriverAdapterProtocol, + SyncArrowBulkOperationsMixin, + SyncDriverAdapterProtocol, + T, +) +from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor - from sqlspec.typing import ModelDTOT, StatementParameterType + # Conditionally import ArrowTable for type checking + from sqlspec.typing import ModelDTOT __all__ = ("OracleAsyncDriver", "OracleSyncDriver") -class OracleSyncDriver(SyncDriverAdapterProtocol["Connection"]): +class OracleSyncDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]): """Oracle Sync Driver Adapter.""" connection: "Connection" @@ -239,8 +247,28 @@ def execute_script( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return str(cursor.rowcount) # pyright: ignore[reportUnknownMemberType] + def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[Connection]" = None, + ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] + """Execute a SQL query and return results as an Apache Arrow Table. -class OracleAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): + Returns: + An Apache Arrow Table containing the query results. + """ + + connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters) + results = connection.fetch_df_all(sql, parameters) + return cast("ArrowTable", ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names())) # pyright: ignore + + +class OracleAsyncDriver( + AsyncArrowBulkOperationsMixin["AsyncConnection"], AsyncDriverAdapterProtocol["AsyncConnection"] +): """Oracle Async Driver Adapter.""" connection: "AsyncConnection" @@ -496,3 +524,26 @@ async def execute_script_returning( return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + + async def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[AsyncConnection]" = None, + ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] + """Execute a SQL query asynchronously and return results as an Apache Arrow Table. + + Args: + sql: The SQL query string. + parameters: Parameters for the query. + connection: Optional connection override. + + Returns: + An Apache Arrow Table containing the query results. + """ + + connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters) + results = await connection.fetch_df_all(sql, parameters) + return ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names()) # pyright: ignore diff --git a/sqlspec/adapters/psycopg/__init__.py b/sqlspec/adapters/psycopg/__init__.py index 6e86d63cd..1105f5b43 100644 --- a/sqlspec/adapters/psycopg/__init__.py +++ b/sqlspec/adapters/psycopg/__init__.py @@ -1,11 +1,16 @@ -from sqlspec.adapters.psycopg.config import PsycopgAsync, PsycopgAsyncPool, PsycopgSync, PsycopgSyncPool +from sqlspec.adapters.psycopg.config import ( + PsycopgAsyncConfig, + PsycopgAsyncPoolConfig, + PsycopgSyncConfig, + PsycopgSyncPoolConfig, +) from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver __all__ = ( - "PsycopgAsync", + "PsycopgAsyncConfig", "PsycopgAsyncDriver", - "PsycopgAsyncPool", - "PsycopgSync", + "PsycopgAsyncPoolConfig", + "PsycopgSyncConfig", "PsycopgSyncDriver", - "PsycopgSyncPool", + "PsycopgSyncPoolConfig", ) diff --git a/sqlspec/adapters/psycopg/config/__init__.py b/sqlspec/adapters/psycopg/config/__init__.py index 7d8481d09..a3ab74788 100644 --- a/sqlspec/adapters/psycopg/config/__init__.py +++ b/sqlspec/adapters/psycopg/config/__init__.py @@ -1,9 +1,9 @@ -from sqlspec.adapters.psycopg.config._async import PsycopgAsync, PsycopgAsyncPool -from sqlspec.adapters.psycopg.config._sync import PsycopgSync, PsycopgSyncPool +from sqlspec.adapters.psycopg.config._async import PsycopgAsyncConfig, PsycopgAsyncPoolConfig +from sqlspec.adapters.psycopg.config._sync import PsycopgSyncConfig, PsycopgSyncPoolConfig __all__ = ( - "PsycopgAsync", - "PsycopgAsyncPool", - "PsycopgSync", - "PsycopgSyncPool", + "PsycopgAsyncConfig", + "PsycopgAsyncPoolConfig", + "PsycopgSyncConfig", + "PsycopgSyncPoolConfig", ) diff --git a/sqlspec/adapters/psycopg/config/_async.py b/sqlspec/adapters/psycopg/config/_async.py index a7d602b5a..a7d1ad1fb 100644 --- a/sqlspec/adapters/psycopg/config/_async.py +++ b/sqlspec/adapters/psycopg/config/_async.py @@ -16,18 +16,18 @@ __all__ = ( - "PsycopgAsync", - "PsycopgAsyncPool", + "PsycopgAsyncConfig", + "PsycopgAsyncPoolConfig", ) @dataclass -class PsycopgAsyncPool(PsycopgGenericPoolConfig[AsyncConnection, AsyncConnectionPool]): +class PsycopgAsyncPoolConfig(PsycopgGenericPoolConfig[AsyncConnection, AsyncConnectionPool]): """Async Psycopg Pool Config""" @dataclass -class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]): +class PsycopgAsyncConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]): """Async Psycopg database Configuration. This class provides the base configuration for Psycopg database connections, extending @@ -37,7 +37,7 @@ class PsycopgAsync(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, Psy with both synchronous and asynchronous connections.([2](https://www.psycopg.org/psycopg3/docs/api/connections.html)) """ - pool_config: "Optional[PsycopgAsyncPool]" = None + pool_config: "Optional[PsycopgAsyncPoolConfig]" = None """Psycopg Pool configuration""" pool_instance: "Optional[AsyncConnectionPool]" = None """Optional pool to use""" @@ -71,7 +71,7 @@ def connection_config_dict(self) -> "dict[str, Any]": self.pool_config, exclude_empty=True, convert_nested=False, - exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type"}), + exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type", "open"}), ) msg = "You must provide a 'pool_config' for this adapter." raise ImproperConfigurationError(msg) @@ -128,7 +128,7 @@ async def create_pool(self) -> "AsyncConnectionPool": raise ImproperConfigurationError(msg) pool_config = self.pool_config_dict - self.pool_instance = AsyncConnectionPool(**pool_config) + self.pool_instance = AsyncConnectionPool(open=False, **pool_config) if self.pool_instance is None: # pyright: ignore[reportUnnecessaryComparison] msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable] raise ImproperConfigurationError(msg) @@ -150,7 +150,7 @@ async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGener AsyncConnection: A database connection from the pool. """ pool = await self.provide_pool(*args, **kwargs) - async with pool.connection() as connection: + async with pool, pool.connection() as connection: yield connection @asynccontextmanager diff --git a/sqlspec/adapters/psycopg/config/_sync.py b/sqlspec/adapters/psycopg/config/_sync.py index ddfa09d46..40a6a0b9c 100644 --- a/sqlspec/adapters/psycopg/config/_sync.py +++ b/sqlspec/adapters/psycopg/config/_sync.py @@ -16,18 +16,18 @@ __all__ = ( - "PsycopgSync", - "PsycopgSyncPool", + "PsycopgSyncConfig", + "PsycopgSyncPoolConfig", ) @dataclass -class PsycopgSyncPool(PsycopgGenericPoolConfig[Connection, ConnectionPool]): +class PsycopgSyncPoolConfig(PsycopgGenericPoolConfig[Connection, ConnectionPool]): """Sync Psycopg Pool Config""" @dataclass -class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriver]): +class PsycopgSyncConfig(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriver]): """Sync Psycopg database Configuration. This class provides the base configuration for Psycopg database connections, extending the generic database configuration with Psycopg-specific settings.([1](https://www.psycopg.org/psycopg3/docs/api/connections.html)) @@ -36,7 +36,7 @@ class PsycopgSync(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriv with both synchronous and asynchronous connections.([2](https://www.psycopg.org/psycopg3/docs/api/connections.html)) """ - pool_config: "Optional[PsycopgSyncPool]" = None + pool_config: "Optional[PsycopgSyncPoolConfig]" = None """Psycopg Pool configuration""" pool_instance: "Optional[ConnectionPool]" = None """Optional pool to use""" @@ -70,7 +70,7 @@ def connection_config_dict(self) -> "dict[str, Any]": self.pool_config, exclude_empty=True, convert_nested=False, - exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type"}), + exclude=pool_only_params.union({"pool_instance", "connection_type", "driver_type", "open"}), ) msg = "You must provide a 'pool_config' for this adapter." raise ImproperConfigurationError(msg) @@ -87,7 +87,7 @@ def pool_config_dict(self) -> "dict[str, Any]": self.pool_config, exclude_empty=True, convert_nested=False, - exclude={"pool_instance", "connection_type", "driver_type"}, + exclude={"pool_instance", "connection_type", "driver_type", "open"}, ) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) @@ -127,7 +127,7 @@ def create_pool(self) -> "ConnectionPool": raise ImproperConfigurationError(msg) pool_config = self.pool_config_dict - self.pool_instance = ConnectionPool(**pool_config) + self.pool_instance = ConnectionPool(open=False, **pool_config) if self.pool_instance is None: # pyright: ignore[reportUnnecessaryComparison] msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable] raise ImproperConfigurationError(msg) @@ -149,7 +149,7 @@ def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connec Connection: A database connection from the pool. """ pool = self.provide_pool(*args, **kwargs) - with pool.connection() as connection: + with pool, pool.connection() as connection: yield connection @contextmanager diff --git a/sqlspec/adapters/sqlite/__init__.py b/sqlspec/adapters/sqlite/__init__.py index af90cebc0..a97d890a5 100644 --- a/sqlspec/adapters/sqlite/__init__.py +++ b/sqlspec/adapters/sqlite/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.sqlite.config import Sqlite +from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.adapters.sqlite.driver import SqliteDriver __all__ = ( - "Sqlite", + "SqliteConfig", "SqliteDriver", ) diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 37fb04e13..087306a65 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -12,11 +12,11 @@ from collections.abc import Generator -__all__ = ("Sqlite",) +__all__ = ("SqliteConfig",) @dataclass -class Sqlite(NoPoolSyncConfig["Connection", "SqliteDriver"]): +class SqliteConfig(NoPoolSyncConfig["Connection", "SqliteDriver"]): """Configuration for SQLite database connections. This class provides configuration options for SQLite database connections, wrapping all parameters diff --git a/sqlspec/base.py b/sqlspec/base.py index de36b9021..4e727950a 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -1,10 +1,10 @@ # ruff: noqa: PLR6301 import re from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Awaitable, Generator -from contextlib import AbstractAsyncContextManager, AbstractContextManager +from collections.abc import Awaitable from dataclasses import dataclass, field from typing import ( + TYPE_CHECKING, Annotated, Any, ClassVar, @@ -19,13 +19,24 @@ from sqlspec.exceptions import NotFoundError from sqlspec.typing import ModelDTOT, StatementParameterType +if TYPE_CHECKING: + from contextlib import AbstractAsyncContextManager, AbstractContextManager + + from pyarrow import Table as ArrowTable + __all__ = ( + "AsyncArrowBulkOperationsMixin", "AsyncDatabaseConfig", + "AsyncDriverAdapterProtocol", + "CommonDriverAttributes", "DatabaseConfigProtocol", "GenericPoolConfig", "NoPoolAsyncConfig", "NoPoolSyncConfig", + "SQLSpec", + "SyncArrowBulkOperationsMixin", "SyncDatabaseConfig", + "SyncDriverAdapterProtocol", ) T = TypeVar("T") @@ -56,14 +67,14 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): connection_type: "type[ConnectionT]" = field(init=False) driver_type: "type[DriverT]" = field(init=False) pool_instance: "Optional[PoolT]" = field(default=None) - __is_async__: ClassVar[bool] = False - __supports_connection_pooling__: ClassVar[bool] = False + __is_async__: "ClassVar[bool]" = False + __supports_connection_pooling__: "ClassVar[bool]" = False def __hash__(self) -> int: return id(self) @abstractmethod - def create_connection(self) -> Union[ConnectionT, Awaitable[ConnectionT]]: + def create_connection(self) -> "Union[ConnectionT, Awaitable[ConnectionT]]": """Create and return a new database connection.""" raise NotImplementedError @@ -72,28 +83,32 @@ def provide_connection( self, *args: Any, **kwargs: Any, - ) -> Union[ - Generator[ConnectionT, None, None], - AsyncGenerator[ConnectionT, None], - AbstractContextManager[ConnectionT], - AbstractAsyncContextManager[ConnectionT], - ]: + ) -> "Union[AbstractContextManager[ConnectionT], AbstractAsyncContextManager[ConnectionT]]": """Provide a database connection context manager.""" raise NotImplementedError + @abstractmethod + def provide_session( + self, + *args: Any, + **kwargs: Any, + ) -> "Union[AbstractContextManager[DriverT], AbstractAsyncContextManager[DriverT]]": + """Provide a database session context manager.""" + raise NotImplementedError + @property @abstractmethod - def connection_config_dict(self) -> dict[str, Any]: + def connection_config_dict(self) -> "dict[str, Any]": """Return the connection configuration as a dict.""" raise NotImplementedError @abstractmethod - def create_pool(self) -> Union[PoolT, Awaitable[PoolT]]: + def create_pool(self) -> "Union[PoolT, Awaitable[PoolT]]": """Create and return connection pool.""" raise NotImplementedError @abstractmethod - def close_pool(self) -> Optional[Awaitable[None]]: + def close_pool(self) -> "Optional[Awaitable[None]]": """Terminate the connection pool.""" raise NotImplementedError @@ -102,7 +117,7 @@ def provide_pool( self, *args: Any, **kwargs: Any, - ) -> Union[PoolT, Awaitable[PoolT], AbstractContextManager[PoolT], AbstractAsyncContextManager[PoolT]]: + ) -> "Union[PoolT, Awaitable[PoolT], AbstractContextManager[PoolT], AbstractAsyncContextManager[PoolT]]": """Provide pool instance.""" raise NotImplementedError @@ -185,18 +200,15 @@ def __init__(self) -> None: self._configs: dict[Any, DatabaseConfigProtocol[Any, Any, Any]] = {} @overload - def add_config(self, config: SyncConfigT) -> type[SyncConfigT]: ... + def add_config(self, config: "SyncConfigT") -> "type[SyncConfigT]": ... @overload - def add_config(self, config: AsyncConfigT) -> type[AsyncConfigT]: ... + def add_config(self, config: "AsyncConfigT") -> "type[AsyncConfigT]": ... def add_config( self, - config: Union[ - SyncConfigT, - AsyncConfigT, - ], - ) -> Union[Annotated[type[SyncConfigT], int], Annotated[type[AsyncConfigT], int]]: # pyright: ignore[reportInvalidTypeVarUse] + config: "Union[SyncConfigT, AsyncConfigT]", + ) -> "Union[Annotated[type[SyncConfigT], int], Annotated[type[AsyncConfigT], int]]": # pyright: ignore[reportInvalidTypeVarUse] """Add a new configuration to the manager. Returns: @@ -207,15 +219,15 @@ def add_config( return key # type: ignore[return-value] # pyright: ignore[reportReturnType] @overload - def get_config(self, name: type[SyncConfigT]) -> SyncConfigT: ... + def get_config(self, name: "type[SyncConfigT]") -> "SyncConfigT": ... @overload - def get_config(self, name: type[AsyncConfigT]) -> AsyncConfigT: ... + def get_config(self, name: "type[AsyncConfigT]") -> "AsyncConfigT": ... def get_config( self, - name: Union[type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]], Any], - ) -> DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]: + name: "Union[type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]], Any]", + ) -> "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]": """Retrieve a configuration by its type. Returns: @@ -234,61 +246,135 @@ def get_config( def get_connection( self, name: Union[ - type[NoPoolSyncConfig[ConnectionT, DriverT]], - type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse] + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", # pyright: ignore[reportInvalidTypeVarUse] ], - ) -> ConnectionT: ... + ) -> "ConnectionT": ... @overload def get_connection( self, name: Union[ - type[NoPoolAsyncConfig[ConnectionT, DriverT]], - type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], # pyright: ignore[reportInvalidTypeVarUse] + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", # pyright: ignore[reportInvalidTypeVarUse] ], - ) -> Awaitable[ConnectionT]: ... + ) -> "Awaitable[ConnectionT]": ... def get_connection( self, name: Union[ - type[NoPoolSyncConfig[ConnectionT, DriverT]], - type[NoPoolAsyncConfig[ConnectionT, DriverT]], - type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], - type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", ], - ) -> Union[ConnectionT, Awaitable[ConnectionT]]: - """Create and return a connection from the specified configuration. + ) -> "Union[ConnectionT, Awaitable[ConnectionT]]": + """Create and return a new database connection from the specified configuration. Args: name: The configuration type to use for creating the connection. Returns: - Either a connection instance or an awaitable that resolves to a connection, - depending on whether the configuration is sync or async. + Either a connection instance or an awaitable that resolves to a connection instance. """ config = self.get_config(name) return config.create_connection() + def get_session( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + ) -> "Union[DriverT, Awaitable[DriverT]]": + """Create and return a new database session from the specified configuration. + + Args: + name: The configuration type to use for creating the session. + + Returns: + Either a driver instance or an awaitable that resolves to a driver instance. + """ + config = self.get_config(name) + connection = self.get_connection(name) + if isinstance(connection, Awaitable): + + async def _create_session() -> DriverT: + return cast("DriverT", config.driver_type(await connection)) # pyright: ignore + + return _create_session() + return cast("DriverT", config.driver_type(connection)) # pyright: ignore + + def provide_connection( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "Union[AbstractContextManager[ConnectionT], AbstractAsyncContextManager[ConnectionT]]": + """Create and provide a database connection from the specified configuration. + + Args: + name: The configuration type to use for creating the connection. + *args: Positional arguments to pass to the configuration's provide_connection method. + **kwargs: Keyword arguments to pass to the configuration's provide_connection method. + + Returns: + Either a synchronous or asynchronous context manager that provides a database connection. + """ + config = self.get_config(name) + return config.provide_connection(*args, **kwargs) + + def provide_session( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "Union[AbstractContextManager[DriverT], AbstractAsyncContextManager[DriverT]]": + """Create and provide a database session from the specified configuration. + + Args: + name: The configuration type to use for creating the session. + *args: Positional arguments to pass to the configuration's provide_session method. + **kwargs: Keyword arguments to pass to the configuration's provide_session method. + + Returns: + Either a synchronous or asynchronous context manager that provides a database session. + """ + config = self.get_config(name) + return config.provide_session(*args, **kwargs) + @overload def get_pool( - self, name: type[Union[NoPoolSyncConfig[ConnectionT, DriverT], NoPoolAsyncConfig[ConnectionT, DriverT]]] - ) -> None: ... # pyright: ignore[reportInvalidTypeVarUse] + self, name: "type[Union[NoPoolSyncConfig[ConnectionT, DriverT], NoPoolAsyncConfig[ConnectionT, DriverT]]]" + ) -> "None": ... # pyright: ignore[reportInvalidTypeVarUse] @overload - def get_pool(self, name: type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> type[PoolT]: ... # pyright: ignore[reportInvalidTypeVarUse] + def get_pool(self, name: "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]") -> "type[PoolT]": ... # pyright: ignore[reportInvalidTypeVarUse] @overload - def get_pool(self, name: type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]) -> Awaitable[type[PoolT]]: ... # pyright: ignore[reportInvalidTypeVarUse] + def get_pool(self, name: "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]") -> "Awaitable[type[PoolT]]": ... # pyright: ignore[reportInvalidTypeVarUse] def get_pool( self, name: Union[ - type[NoPoolSyncConfig[ConnectionT, DriverT]], - type[NoPoolAsyncConfig[ConnectionT, DriverT]], - type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], - type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", ], - ) -> Union[type[PoolT], Awaitable[type[PoolT]], None]: + ) -> "Union[type[PoolT], Awaitable[type[PoolT]], None]": """Create and return a connection pool from the specified configuration. Args: @@ -306,12 +392,12 @@ def get_pool( def close_pool( self, name: Union[ - type[NoPoolSyncConfig[ConnectionT, DriverT]], - type[NoPoolAsyncConfig[ConnectionT, DriverT]], - type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]], - type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]], + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", ], - ) -> Optional[Awaitable[None]]: + ) -> "Optional[Awaitable[None]]": """Close the connection pool for the specified configuration. Args: @@ -333,6 +419,8 @@ class CommonDriverAttributes(Generic[ConnectionT]): """The parameter style placeholder supported by the underlying database driver (e.g., '?', '%s').""" connection: ConnectionT """The connection to the underlying database.""" + __supports_arrow__: ClassVar[bool] = False + """Indicates if the driver supports Apache Arrow operations.""" def _connection(self, connection: "Optional[ConnectionT]" = None) -> "ConnectionT": return connection if connection is not None else self.connection @@ -434,19 +522,48 @@ def _process_sql_params( return final_sql, processed_params -class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): - connection: ConnectionT +class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): + """Mixin for sync drivers supporting bulk Apache Arrow operations.""" + + __supports_arrow__: "ClassVar[bool]" = True def __init__(self, connection: ConnectionT) -> None: self.connection = connection + @abstractmethod + def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[ConnectionT]" = None, + ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] + """Execute a SQL query and return results as an Apache Arrow Table. + + Args: + sql: The SQL query string. + parameters: Parameters for the query. + connection: Optional connection override. + + Returns: + An Apache Arrow Table containing the query results. + """ + raise NotImplementedError + + +class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): + connection: "ConnectionT" + + def __init__(self, connection: "ConnectionT") -> None: + self.connection = connection + @abstractmethod def select( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, + connection: "Optional[ConnectionT]" = None, schema_type: Optional[type[ModelDTOT]] = None, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @@ -519,88 +636,114 @@ def execute_script( ) -> str: ... +class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]): + """Mixin for async drivers supporting bulk Apache Arrow operations.""" + + __supports_arrow__: "ClassVar[bool]" = True + + @abstractmethod + async def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + connection: "Optional[ConnectionT]" = None, + ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] + """Execute a SQL query and return results as an Apache Arrow Table. + + Args: + sql: The SQL query string. + parameters: Parameters for the query. + connection: Optional connection override. + + Returns: + An Apache Arrow Table containing the query results. + """ + raise NotImplementedError + + class AsyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): - connection: ConnectionT + connection: "ConnectionT" - def __init__(self, connection: ConnectionT) -> None: + def __init__(self, connection: "ConnectionT") -> None: self.connection = connection @abstractmethod async def select( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[ModelDTOT]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod async def select_one( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[ModelDTOT]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod async def select_one_or_none( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[ModelDTOT]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod async def select_value( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[T]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[T]]" = None, ) -> "Union[Any, T]": ... @abstractmethod async def select_value_or_none( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[T]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[T]]" = None, ) -> "Optional[Union[Any, T]]": ... @abstractmethod async def insert_update_delete( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, + connection: "Optional[ConnectionT]" = None, ) -> int: ... @abstractmethod async def insert_update_delete_returning( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, - schema_type: Optional[type[ModelDTOT]] = None, + connection: "Optional[ConnectionT]" = None, + schema_type: "Optional[type[ModelDTOT]]" = None, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... @abstractmethod async def execute_script( self, sql: str, - parameters: Optional[StatementParameterType] = None, + parameters: "Optional[StatementParameterType]" = None, /, - connection: Optional[ConnectionT] = None, + connection: "Optional[ConnectionT]" = None, ) -> str: ... diff --git a/sqlspec/extensions/litestar/config.py b/sqlspec/extensions/litestar/config.py index 164246982..59e0ed167 100644 --- a/sqlspec/extensions/litestar/config.py +++ b/sqlspec/extensions/litestar/config.py @@ -12,6 +12,7 @@ lifespan_handler_maker, manual_handler_maker, pool_provider_maker, + session_provider_maker, ) if TYPE_CHECKING: @@ -25,6 +26,7 @@ from sqlspec.base import ( AsyncConfigT, ConnectionT, + DriverT, PoolT, SyncConfigT, ) @@ -33,6 +35,7 @@ DEFAULT_COMMIT_MODE: CommitMode = "manual" DEFAULT_CONNECTION_KEY = "db_connection" DEFAULT_POOL_KEY = "db_pool" +DEFAULT_SESSION_KEY = "db_session" @dataclass @@ -40,11 +43,13 @@ class DatabaseConfig: config: "Union[SyncConfigT, AsyncConfigT]" = field() # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues] connection_key: str = field(default=DEFAULT_CONNECTION_KEY) pool_key: str = field(default=DEFAULT_POOL_KEY) + session_key: str = field(default=DEFAULT_SESSION_KEY) commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE) extra_commit_statuses: "Optional[set[int]]" = field(default=None) extra_rollback_statuses: "Optional[set[int]]" = field(default=None) connection_provider: "Callable[[State,Scope], Awaitable[ConnectionT]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues] pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues] + session_provider: "Callable[[State,Scope], Awaitable[DriverT]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues] before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False) lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field( init=False, @@ -79,3 +84,4 @@ def __post_init__(self) -> None: self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key) self.connection_provider = connection_provider_maker(connection_key=self.connection_key, config=self.config) self.pool_provider = pool_provider_maker(pool_key=self.pool_key, config=self.config) + self.session_provider = session_provider_maker(session_key=self.session_key, config=self.config) diff --git a/sqlspec/extensions/litestar/handlers.py b/sqlspec/extensions/litestar/handlers.py index a62d35fa3..1d5e336f9 100644 --- a/sqlspec/extensions/litestar/handlers.py +++ b/sqlspec/extensions/litestar/handlers.py @@ -186,3 +186,28 @@ async def provide_pool(state: "State", scope: "Scope") -> "PoolT": return cast("PoolT", pool) return provide_pool + + +def session_provider_maker( + session_key: str, + config: "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]", +) -> "Callable[[State,Scope], Awaitable[DriverT]]": + """Build the session provider for the database configuration. + + Args: + session_key: The dependency key to use for the session within Litestar. + config: The database configuration. + + Returns: + The generated session provider for the database. + """ + + async def provide_session(state: "State", scope: "Scope") -> "DriverT": + session = get_sqlspec_scope_state(scope, session_key) + if session is None: + connection = await maybe_async_(config.create_connection)() + session = config.driver_type(connection=connection) # pyright: ignore[reportCallIssue] + set_sqlspec_scope_state(scope, session_key, session) + return cast("DriverT", session) + + return provide_session diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 651effc14..7eea9f783 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -23,6 +23,7 @@ DEFAULT_COMMIT_MODE: CommitMode = "manual" DEFAULT_CONNECTION_KEY = "db_connection" DEFAULT_POOL_KEY = "db_pool" +DEFAULT_SESSION_KEY = "db_session" class SQLSpec(InitPluginProtocol, SQLSpecBase): @@ -85,7 +86,11 @@ def on_app_init(self, app_config: "AppConfig") -> "AppConfig": app_config.before_send.append(c.before_send_handler) app_config.lifespan.append(c.lifespan_handler) # pyright: ignore[reportUnknownMemberType] app_config.dependencies.update( - {c.connection_key: Provide(c.connection_provider), c.pool_key: Provide(c.pool_provider)}, + { + c.connection_key: Provide(c.connection_provider), + c.pool_key: Provide(c.pool_provider), + c.session_key: Provide(c.session_provider), + }, ) return app_config diff --git a/sqlspec/typing.py b/sqlspec/typing.py index fd3387c49..bb2a1cf8c 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -7,8 +7,10 @@ from sqlspec._typing import ( LITESTAR_INSTALLED, MSGSPEC_INSTALLED, + PYARROW_INSTALLED, PYDANTIC_INSTALLED, UNSET, + ArrowTable, BaseModel, DataclassProtocol, DTOData, @@ -486,9 +488,11 @@ def schema_dump( # noqa: PLR0911 __all__ = ( "LITESTAR_INSTALLED", "MSGSPEC_INSTALLED", + "PYARROW_INSTALLED", "PYDANTIC_INSTALLED", "PYDANTIC_USE_FAILFAST", "UNSET", + "ArrowTable", "BaseModel", "DataclassProtocol", "Empty", @@ -539,3 +543,8 @@ def schema_dump( # noqa: PLR0911 from sqlspec._typing import UNSET, Struct, UnsetType, convert else: from msgspec import UNSET, Struct, UnsetType, convert # noqa: TC004 + + if not PYARROW_INSTALLED: + from sqlspec._typing import ArrowTable + else: + from pyarrow import Table as ArrowTable # noqa: TC004 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e69de29bb..57f46b379 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytestmark = pytest.mark.integration diff --git a/tests/integration/test_adapters/test_adbc/__init__.py b/tests/integration/test_adapters/test_adbc/__init__.py index 7ad113d6f..1376b7809 100644 --- a/tests/integration/test_adapters/test_adbc/__init__.py +++ b/tests/integration/test_adapters/test_adbc/__init__.py @@ -1 +1,5 @@ """Tests for ADBC adapter with PostgreSQL.""" + +import pytest + +pytestmark = pytest.mark.adbc diff --git a/tests/integration/test_adapters/test_adbc/conftest.py b/tests/integration/test_adapters/test_adbc/conftest.py index 016b31230..5c2870231 100644 --- a/tests/integration/test_adapters/test_adbc/conftest.py +++ b/tests/integration/test_adapters/test_adbc/conftest.py @@ -7,7 +7,7 @@ import pytest from pytest_databases.docker.postgres import PostgresService -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig F = TypeVar("F", bound=Callable[..., Any]) @@ -28,8 +28,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: @pytest.fixture(scope="session") -def adbc_session(postgres_service: PostgresService) -> Adbc: +def adbc_session(postgres_service: PostgresService) -> AdbcConfig: """Create an ADBC session for PostgreSQL.""" - return Adbc( + return AdbcConfig( uri=f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", ) diff --git a/tests/integration/test_adapters/test_adbc/test_connection.py b/tests/integration/test_adapters/test_adbc/test_connection.py index 1b38f4a74..003756bae 100644 --- a/tests/integration/test_adapters/test_adbc/test_connection.py +++ b/tests/integration/test_adapters/test_adbc/test_connection.py @@ -3,19 +3,21 @@ from __future__ import annotations +import pytest from pytest_databases.docker.postgres import PostgresService -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig # Import the decorator from tests.integration.test_adapters.test_adbc.conftest import xfail_if_driver_missing +@pytest.mark.xdist_group("postgres") @xfail_if_driver_missing def test_connection(postgres_service: PostgresService) -> None: """Test ADBC connection to PostgreSQL.""" # Test direct connection - config = Adbc( + config = AdbcConfig( uri=f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", driver_name="adbc_driver_postgresql.dbapi.connect", ) diff --git a/tests/integration/test_adapters/test_adbc/test_driver_bigquery.py b/tests/integration/test_adapters/test_adbc/test_driver_bigquery.py index 61d515ac6..39789ba1b 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_bigquery.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_bigquery.py @@ -4,18 +4,19 @@ from typing import Any, Literal +import pyarrow as pa import pytest from adbc_driver_bigquery import DatabaseOptions from pytest_databases.docker.bigquery import BigQueryService -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig from tests.integration.test_adapters.test_adbc.conftest import xfail_if_driver_missing ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") -def adbc_session(bigquery_service: BigQueryService) -> Adbc: +@pytest.fixture +def adbc_session(bigquery_service: BigQueryService) -> AdbcConfig: """Create an ADBC session for BigQuery.""" db_kwargs = { DatabaseOptions.PROJECT_ID.value: bigquery_service.project, @@ -23,15 +24,7 @@ def adbc_session(bigquery_service: BigQueryService) -> Adbc: DatabaseOptions.AUTH_TYPE.value: DatabaseOptions.AUTH_VALUE_BIGQUERY.value, } - return Adbc(driver_name="adbc_driver_bigquery", db_kwargs=db_kwargs) - - -@pytest.fixture(autouse=True) -def cleanup_test_table(adbc_session: Adbc) -> None: - """Clean up the test table before each test.""" - with adbc_session.provide_session() as driver: - # Using IF EXISTS is generally safer for cleanup - driver.execute_script("DROP TABLE IF EXISTS test_table") + return AdbcConfig(driver_name="adbc_driver_bigquery", db_kwargs=db_kwargs) @pytest.mark.parametrize( @@ -43,7 +36,8 @@ def cleanup_test_table(adbc_session: Adbc) -> None: ) @xfail_if_driver_missing @pytest.mark.xfail(reason="BigQuery emulator may cause failures") -def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle, insert_id: int) -> None: +@pytest.mark.xdist_group("bigquery") +def test_driver_select(adbc_session: AdbcConfig, params: Any, style: ParamStyle, insert_id: int) -> None: """Test select functionality with different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table (Use BigQuery compatible types) @@ -73,6 +67,7 @@ def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle, inser results = driver.select(select_sql, select_params) assert len(results) == 1 assert results[0]["name"] == expected_name + driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -84,7 +79,8 @@ def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle, inser ) @xfail_if_driver_missing @pytest.mark.xfail(reason="BigQuery emulator may cause failures") -def test_driver_select_value(adbc_session: Adbc, params: Any, style: ParamStyle, insert_id: int) -> None: +@pytest.mark.xdist_group("bigquery") +def test_driver_select_value(adbc_session: AdbcConfig, params: Any, style: ParamStyle, insert_id: int) -> None: """Test select_value functionality with different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table @@ -113,11 +109,13 @@ def test_driver_select_value(adbc_session: Adbc, params: Any, style: ParamStyle, # Select and verify value = driver.select_value(select_sql, select_params) assert value == expected_name + driver.execute_script("DROP TABLE IF EXISTS test_table") @xfail_if_driver_missing @pytest.mark.xfail(reason="BigQuery emulator may cause failures") -def test_driver_insert(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("bigquery") +def test_driver_insert(adbc_session: AdbcConfig) -> None: """Test insert functionality using positional parameters.""" with adbc_session.provide_session() as driver: # Create test table @@ -140,11 +138,13 @@ def test_driver_insert(adbc_session: Adbc) -> None: results = driver.select("SELECT name FROM test_table WHERE id = ?", (1,)) assert len(results) == 1 assert results[0]["name"] == "test_insert" + driver.execute_script("DROP TABLE IF EXISTS test_table") @xfail_if_driver_missing @pytest.mark.xfail(reason="BigQuery emulator may cause failures") -def test_driver_select_normal(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("bigquery") +def test_driver_select_normal(adbc_session: AdbcConfig) -> None: """Test select functionality using positional parameters.""" with adbc_session.provide_session() as driver: # Create test table @@ -165,11 +165,13 @@ def test_driver_select_normal(adbc_session: Adbc) -> None: results = driver.select(select_sql, (10,)) assert len(results) == 1 assert results[0]["name"] == "test_select_normal" + driver.execute_script("DROP TABLE IF EXISTS test_table") @xfail_if_driver_missing @pytest.mark.xfail(reason="BigQuery emulator may cause failures") -def test_execute_script_multiple_statements(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("bigquery") +def test_execute_script_multiple_statements(adbc_session: AdbcConfig) -> None: """Test execute_script with multiple statements.""" with adbc_session.provide_session() as driver: script = """ @@ -188,3 +190,38 @@ def test_execute_script_multiple_statements(adbc_session: Adbc) -> None: value = driver.select_value("SELECT name FROM test_table WHERE id = ?", (1,)) assert value == "script_test" + driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xfail(reason="BigQuery emulator may cause failures") +@pytest.mark.xdist_group("bigquery") +def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: + """Test select_arrow functionality for ADBC BigQuery.""" + with adbc_session.provide_session() as driver: + # Create test table + sql = """ + CREATE TABLE test_table ( + id INT64, + name STRING + ); + """ + driver.execute_script(sql) + + # Insert test record using positional parameters (?) + insert_sql = "INSERT INTO test_table (id, name) VALUES (?, ?)" + driver.insert_update_delete(insert_sql, (100, "arrow_name")) + + # Select and verify with select_arrow using positional parameters (?) + select_sql = "SELECT name, id FROM test_table WHERE name = ?" + arrow_table = driver.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # BigQuery might not guarantee column order, sort for check + assert sorted(arrow_table.column_names) == sorted(["name", "id"]) + # Check data irrespective of column order + assert arrow_table.column("name").to_pylist() == ["arrow_name"] + assert arrow_table.column("id").to_pylist() == [100] + driver.execute_script("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py index 61687fd0e..23912f19d 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py @@ -4,9 +4,10 @@ from typing import Any, Literal +import pyarrow as pa import pytest -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig # Import the decorator from tests.integration.test_adapters.test_adbc.conftest import xfail_if_driver_missing @@ -14,22 +15,14 @@ ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") -def adbc_session() -> Adbc: +@pytest.fixture +def adbc_session() -> AdbcConfig: """Create an ADBC session for DuckDB using URI.""" - return Adbc( + return AdbcConfig( uri="duckdb://:memory:", ) -@pytest.fixture(autouse=True) -def cleanup_test_table(adbc_session: Adbc) -> None: - """Clean up the test table before each test.""" - with adbc_session.provide_session() as driver: - driver.execute_script("DROP TABLE IF EXISTS test_table") - driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") - - @pytest.mark.parametrize( ("params", "style"), [ @@ -38,7 +31,8 @@ def cleanup_test_table(adbc_session: Adbc) -> None: ], ) @xfail_if_driver_missing -def test_driver_insert_returning(adbc_session: Adbc, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("duckdb") +def test_driver_insert_returning(adbc_session: AdbcConfig, params: Any, style: ParamStyle) -> None: """Test insert returning functionality with different parameter styles.""" with adbc_session.provide_session() as driver: create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" @@ -61,6 +55,8 @@ def test_driver_insert_returning(adbc_session: Adbc, params: Any, style: ParamSt assert result is not None assert result["name"] == "test_name" assert result["id"] is not None + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") @pytest.mark.parametrize( @@ -71,7 +67,8 @@ def test_driver_insert_returning(adbc_session: Adbc, params: Any, style: ParamSt ], ) @xfail_if_driver_missing -def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("duckdb") +def test_driver_select(adbc_session: AdbcConfig, params: Any, style: ParamStyle) -> None: """Test select functionality with different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table @@ -99,6 +96,8 @@ def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle) -> No results = driver.select(select_sql, params) assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") @pytest.mark.parametrize( @@ -109,7 +108,8 @@ def test_driver_select(adbc_session: Adbc, params: Any, style: ParamStyle) -> No ], ) @xfail_if_driver_missing -def test_driver_select_value(adbc_session: Adbc, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("duckdb") +def test_driver_select_value(adbc_session: AdbcConfig, params: Any, style: ParamStyle) -> None: """Test select_value functionality with different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table @@ -136,10 +136,13 @@ def test_driver_select_value(adbc_session: Adbc, params: Any, style: ParamStyle) """ % ("$1" if style == "tuple_binds" else ":name") value = driver.select_value(select_sql, params) assert value == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") @xfail_if_driver_missing -def test_driver_insert(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("duckdb") +def test_driver_insert(adbc_session: AdbcConfig) -> None: """Test insert functionality.""" with adbc_session.provide_session() as driver: # Create test table @@ -160,10 +163,13 @@ def test_driver_insert(adbc_session: Adbc) -> None: """ row_count = driver.insert_update_delete(insert_sql, ("test_name",)) assert row_count in (0, 1, -1) + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") @xfail_if_driver_missing -def test_driver_select_normal(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("duckdb") +def test_driver_select_normal(adbc_session: AdbcConfig) -> None: """Test select functionality.""" with adbc_session.provide_session() as driver: # Create test table @@ -189,6 +195,8 @@ def test_driver_select_normal(adbc_session: Adbc) -> None: results = driver.select(select_sql, {"name": "test_name"}) assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") @pytest.mark.parametrize( @@ -200,7 +208,8 @@ def test_driver_select_normal(adbc_session: Adbc) -> None: ], ) @xfail_if_driver_missing -def test_param_styles(adbc_session: Adbc, param_style: str) -> None: +@pytest.mark.xdist_group("duckdb") +def test_param_styles(adbc_session: AdbcConfig, param_style: str) -> None: """Test different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table @@ -226,3 +235,44 @@ def test_param_styles(adbc_session: Adbc, param_style: str) -> None: results = driver.select(select_sql, ("test_name",)) assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: + """Test select_arrow functionality for ADBC DuckDB.""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using a known param style ($1 for duckdb) + insert_sql = """ + INSERT INTO test_table (name) + VALUES ($1) + """ + driver.insert_update_delete(insert_sql, ("arrow_name",)) + + # Select and verify with select_arrow using a known param style + select_sql = "SELECT name, id FROM test_table WHERE name = $1" + arrow_table = driver.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # DuckDB should return columns in selected order + assert arrow_table.column_names == ["name", "id"] + assert arrow_table.column("name").to_pylist() == ["arrow_name"] + # Assuming id is 1 for the inserted record + assert arrow_table.column("id").to_pylist() == [1] + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") diff --git a/tests/integration/test_adapters/test_adbc/test_driver_postgres.py b/tests/integration/test_adapters/test_adbc/test_driver_postgres.py index ec33c5121..b4544a537 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_postgres.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_postgres.py @@ -5,22 +5,23 @@ from collections.abc import Generator from typing import Any, Literal +import pyarrow as pa import pytest from pytest_databases.docker.postgres import PostgresService -from sqlspec.adapters.adbc import Adbc, AdbcDriver +from sqlspec.adapters.adbc import AdbcConfig, AdbcDriver ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") +@pytest.fixture def adbc_postgres_session(postgres_service: PostgresService) -> Generator[AdbcDriver, None, None]: """Create an ADBC postgres session with a test table. Returns: A configured ADBC postgres session with a test table. """ - adapter = Adbc( + adapter = AdbcConfig( uri=f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", ) try: @@ -48,6 +49,7 @@ def adbc_postgres_session(postgres_service: PostgresService) -> Generator[AdbcDr pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("postgres") def test_insert_update_delete_returning(adbc_postgres_session: AdbcDriver, params: Any, style: ParamStyle) -> None: """Test insert_update_delete_returning with different parameter styles.""" # Clear table before test @@ -80,6 +82,7 @@ def test_insert_update_delete_returning(adbc_postgres_session: AdbcDriver, param pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("postgres") def test_select(adbc_postgres_session: AdbcDriver, params: Any, style: ParamStyle) -> None: # pyright: ignore """Test select functionality with different parameter styles.""" # Clear table before test @@ -111,6 +114,7 @@ def test_select(adbc_postgres_session: AdbcDriver, params: Any, style: ParamStyl pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("postgres") def test_select_one(adbc_postgres_session: AdbcDriver, params: Any, style: ParamStyle) -> None: """Test select_one functionality with different parameter styles.""" # Clear table before test @@ -142,6 +146,7 @@ def test_select_one(adbc_postgres_session: AdbcDriver, params: Any, style: Param pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("postgres") def test_select_value(adbc_postgres_session: AdbcDriver, params: Any, style: ParamStyle) -> None: """Test select_value functionality with different parameter styles.""" # Clear table before test @@ -165,3 +170,32 @@ def test_select_value(adbc_postgres_session: AdbcDriver, params: Any, style: Par value = adbc_postgres_session.select_value(sql, select_params) assert value == "test_name" + + +@pytest.mark.xdist_group("postgres") +def test_select_arrow(adbc_postgres_session: AdbcDriver) -> None: + """Test select_arrow functionality for ADBC Postgres.""" + # Clear table before test + adbc_postgres_session.execute_script("DELETE FROM test_table", None) + + # Insert test record using $1 param style + insert_sql = """ + INSERT INTO test_table (name) + VALUES ($1) + """ + adbc_postgres_session.insert_update_delete(insert_sql, ("arrow_name",)) + + # Select and verify with select_arrow using $1 param style + select_sql = "SELECT name, id FROM test_table WHERE name = $1" + arrow_table = adbc_postgres_session.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # Postgres should return columns in selected order + assert arrow_table.column_names == ["name", "id"] + assert arrow_table.column("name").to_pylist() == ["arrow_name"] + # Assuming id is 1 for the inserted record (check might need adjustment if SERIAL doesn't guarantee 1) + # Let's check type and existence instead of exact value + assert arrow_table.column("id").to_pylist()[0] is not None + assert isinstance(arrow_table.column("id").to_pylist()[0], int) diff --git a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py index ec720f7d9..fa7c5ea7e 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py @@ -4,9 +4,10 @@ from typing import Any, Literal +import pyarrow as pa import pytest -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig # Import the decorator from tests.integration.test_adapters.test_adbc.conftest import xfail_if_driver_missing @@ -14,21 +15,14 @@ ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") -def adbc_session() -> Adbc: +@pytest.fixture +def adbc_session() -> AdbcConfig: """Create an ADBC session for SQLite using URI.""" - return Adbc( + return AdbcConfig( uri="sqlite://:memory:", ) -@pytest.fixture(autouse=True) -def cleanup_test_table(adbc_session: Adbc) -> None: - """Clean up the test table before each test.""" - with adbc_session.provide_session() as driver: - driver.execute_script("DROP TABLE IF EXISTS test_table") - - @pytest.mark.parametrize( ("params", "style"), [ @@ -37,7 +31,8 @@ def cleanup_test_table(adbc_session: Adbc) -> None: ], ) @xfail_if_driver_missing -def test_driver_insert_returning(adbc_session: Adbc, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("sqlite") +def test_driver_insert_returning(adbc_session: AdbcConfig, params: Any, style: ParamStyle) -> None: """Test insert returning functionality with different parameter styles.""" with adbc_session.provide_session() as driver: sql = """ @@ -68,9 +63,12 @@ def test_driver_insert_returning(adbc_session: Adbc, params: Any, style: ParamSt assert result["name"] == "test_name" assert result["id"] is not None + driver.execute_script("DROP TABLE IF EXISTS test_table") + @xfail_if_driver_missing -def test_driver_select(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("sqlite") +def test_driver_select(adbc_session: AdbcConfig) -> None: """Test select functionality with simple tuple parameters.""" params = ("test_name",) with adbc_session.provide_session() as driver: @@ -93,9 +91,12 @@ def test_driver_select(adbc_session: Adbc) -> None: assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + @xfail_if_driver_missing -def test_driver_select_value(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("sqlite") +def test_driver_select_value(adbc_session: AdbcConfig) -> None: """Test select_value functionality with simple tuple parameters.""" params = ("test_name",) with adbc_session.provide_session() as driver: @@ -117,9 +118,12 @@ def test_driver_select_value(adbc_session: Adbc) -> None: value = driver.select_value(select_sql, params) assert value == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + @xfail_if_driver_missing -def test_driver_insert(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("sqlite") +def test_driver_insert(adbc_session: AdbcConfig) -> None: """Test insert functionality.""" with adbc_session.provide_session() as driver: # Create test table @@ -139,9 +143,12 @@ def test_driver_insert(adbc_session: Adbc) -> None: row_count = driver.insert_update_delete(insert_sql, ("test_name",)) assert row_count == 1 or row_count == -1 + driver.execute_script("DROP TABLE IF EXISTS test_table") + @xfail_if_driver_missing -def test_driver_select_normal(adbc_session: Adbc) -> None: +@pytest.mark.xdist_group("sqlite") +def test_driver_select_normal(adbc_session: AdbcConfig) -> None: """Test select functionality.""" with adbc_session.provide_session() as driver: # Create test table @@ -166,6 +173,8 @@ def test_driver_select_normal(adbc_session: Adbc) -> None: assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + @pytest.mark.parametrize( "param_style", @@ -176,7 +185,8 @@ def test_driver_select_normal(adbc_session: Adbc) -> None: ], ) @xfail_if_driver_missing -def test_param_styles(adbc_session: Adbc, param_style: str) -> None: +@pytest.mark.xdist_group("sqlite") +def test_param_styles(adbc_session: AdbcConfig, param_style: str) -> None: """Test different parameter styles.""" with adbc_session.provide_session() as driver: # Create test table @@ -200,3 +210,43 @@ def test_param_styles(adbc_session: Adbc, param_style: str) -> None: results = driver.select(select_sql, ("test_name",)) assert len(results) == 1 assert results[0]["name"] == "test_name" + + driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("sqlite") +def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: + """Test select_arrow functionality.""" + with adbc_session.provide_session() as driver: + # Create test table + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + driver.insert_update_delete(insert_sql, ("arrow_name",)) + + # Select and verify with select_arrow + select_sql = "SELECT name, id FROM test_table WHERE name = ?" + arrow_table = driver.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # Note: Column order might vary depending on DB/driver, adjust if needed + # Sorting column names for consistent check + assert sorted(arrow_table.column_names) == sorted(["name", "id"]) + # Check data irrespective of column order + assert arrow_table.column("name").to_pylist() == ["arrow_name"] + # Assuming id is 1 for the inserted record + assert arrow_table.column("id").to_pylist() == [1] + driver.execute_script("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/test_adapters/test_aiosqlite/__init__.py b/tests/integration/test_adapters/test_aiosqlite/__init__.py index f1305e919..c79d8c054 100644 --- a/tests/integration/test_adapters/test_aiosqlite/__init__.py +++ b/tests/integration/test_adapters/test_aiosqlite/__init__.py @@ -1 +1,5 @@ """Integration tests for sqlspec adapters.""" + +import pytest + +pytestmark = [pytest.mark.sqlite, pytest.mark.aiosqlite] diff --git a/tests/integration/test_adapters/test_aiosqlite/conftest.py b/tests/integration/test_adapters/test_aiosqlite/conftest.py deleted file mode 100644 index 2f8615e2d..000000000 --- a/tests/integration/test_adapters/test_aiosqlite/conftest.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import asyncio -from collections.abc import Generator - -import pytest - - -@pytest.fixture(scope="session") -def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: - """Create an instance of the default event loop for each test case.""" - import asyncio - - loop = asyncio.new_event_loop() - yield loop - loop.close() diff --git a/tests/integration/test_adapters/test_aiosqlite/test_connection.py b/tests/integration/test_adapters/test_aiosqlite/test_connection.py index 4a05e065e..9f851cef6 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_connection.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_connection.py @@ -2,14 +2,15 @@ import pytest -from sqlspec.adapters.aiosqlite import Aiosqlite +from sqlspec.adapters.aiosqlite import AiosqliteConfig +@pytest.mark.xdist_group("sqlite") @pytest.mark.asyncio async def test_connection() -> None: """Test connection components.""" # Test direct connection - config = Aiosqlite() + config = AiosqliteConfig() async with config.provide_connection() as conn: assert conn is not None diff --git a/tests/integration/test_adapters/test_aiosqlite/test_driver.py b/tests/integration/test_adapters/test_aiosqlite/test_driver.py index 79ff75c87..3c49e92a4 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_driver.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_driver.py @@ -7,22 +7,21 @@ from typing import Any, Literal import pytest -import pytest_asyncio -from sqlspec.adapters.aiosqlite import Aiosqlite, AiosqliteDriver +from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver from tests.fixtures.sql_utils import create_tuple_or_dict_params, format_sql ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest_asyncio.fixture(scope="session") +@pytest.fixture async def aiosqlite_session() -> AsyncGenerator[AiosqliteDriver, None]: """Create a SQLite session with a test table. Returns: A configured SQLite session with a test table. """ - adapter = Aiosqlite() + adapter = AiosqliteConfig() create_table_sql = """ CREATE TABLE IF NOT EXISTS test_table ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -36,12 +35,6 @@ async def aiosqlite_session() -> AsyncGenerator[AiosqliteDriver, None]: await session.execute_script("DROP TABLE IF EXISTS test_table", None) -@pytest_asyncio.fixture(autouse=True) -async def cleanup_table(aiosqlite_session: AiosqliteDriver) -> None: - """Clean up the test table before each test.""" - await aiosqlite_session.execute_script("DELETE FROM test_table", None) - - @pytest.mark.parametrize( ("params", "style"), [ @@ -49,6 +42,7 @@ async def cleanup_table(aiosqlite_session: AiosqliteDriver) -> None: pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") @pytest.mark.asyncio async def test_insert_update_delete_returning( aiosqlite_session: AiosqliteDriver, params: Any, style: ParamStyle @@ -70,20 +64,7 @@ async def test_insert_update_delete_returning( assert result is not None assert result["name"] == "test_name" assert result["id"] is not None - else: - # Alternative for older SQLite: Insert and then get last row id - sql_template = """ - INSERT INTO test_table (name) - VALUES ({}) - """ - sql = format_sql(sql_template, ["name"], style, "aiosqlite") - - await aiosqlite_session.insert_update_delete(sql, params) - - # Get the last inserted ID using select_value - select_last_id_sql = "SELECT last_insert_rowid()" - inserted_id = await aiosqlite_session.select_value(select_last_id_sql) - assert inserted_id is not None + await aiosqlite_session.execute_script("DELETE FROM test_table") @pytest.mark.parametrize( @@ -93,6 +74,7 @@ async def test_insert_update_delete_returning( pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") @pytest.mark.asyncio async def test_select(aiosqlite_session: AiosqliteDriver, params: Any, style: ParamStyle) -> None: """Test select functionality with different parameter styles.""" @@ -110,6 +92,7 @@ async def test_select(aiosqlite_session: AiosqliteDriver, params: Any, style: Pa results = await aiosqlite_session.select(select_sql, empty_params) assert len(results) == 1 assert results[0]["name"] == "test_name" + await aiosqlite_session.execute_script("DELETE FROM test_table") @pytest.mark.parametrize( @@ -119,6 +102,7 @@ async def test_select(aiosqlite_session: AiosqliteDriver, params: Any, style: Pa pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") @pytest.mark.asyncio async def test_select_one(aiosqlite_session: AiosqliteDriver, params: Any, style: ParamStyle) -> None: """Test select_one functionality with different parameter styles.""" @@ -141,6 +125,7 @@ async def test_select_one(aiosqlite_session: AiosqliteDriver, params: Any, style result = await aiosqlite_session.select_one(sql, select_params) assert result is not None assert result["name"] == "test_name" + await aiosqlite_session.execute_script("DELETE FROM test_table") @pytest.mark.parametrize( @@ -150,6 +135,7 @@ async def test_select_one(aiosqlite_session: AiosqliteDriver, params: Any, style pytest.param({"name": "test_name"}, {"id": 1}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") @pytest.mark.asyncio async def test_select_value( aiosqlite_session: AiosqliteDriver, @@ -179,3 +165,4 @@ async def test_select_value( test_id_params = create_tuple_or_dict_params([inserted_id], ["id"], style) value = await aiosqlite_session.select_value(sql, test_id_params) assert value == "test_name" + await aiosqlite_session.execute_script("DELETE FROM test_table") diff --git a/tests/integration/test_adapters/test_asyncmy/__init__.py b/tests/integration/test_adapters/test_asyncmy/__init__.py index e69de29bb..4af6321e1 100644 --- a/tests/integration/test_adapters/test_asyncmy/__init__.py +++ b/tests/integration/test_adapters/test_asyncmy/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytestmark = [pytest.mark.mysql, pytest.mark.asyncmy] diff --git a/tests/integration/test_adapters/test_asyncmy/test_connection.py b/tests/integration/test_adapters/test_asyncmy/test_connection.py index dd1fd9160..fd70e0125 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_connection.py +++ b/tests/integration/test_adapters/test_asyncmy/test_connection.py @@ -1,16 +1,17 @@ import pytest from pytest_databases.docker.mysql import MySQLService -from sqlspec.adapters.asyncmy import Asyncmy, AsyncmyPool +from sqlspec.adapters.asyncmy import AsyncmyConfig, AsyncmyPoolConfig pytestmark = pytest.mark.asyncio(loop_scope="session") +@pytest.mark.xdist_group("mysql") async def test_async_connection(mysql_service: MySQLService) -> None: """Test async connection components.""" # Test direct connection - async_config = Asyncmy( - pool_config=AsyncmyPool( + async_config = AsyncmyConfig( + pool_config=AsyncmyPoolConfig( host=mysql_service.host, port=mysql_service.port, user=mysql_service.user, @@ -28,7 +29,7 @@ async def test_async_connection(mysql_service: MySQLService) -> None: assert result == (1,) # Test connection pool - pool_config = AsyncmyPool( + pool_config = AsyncmyPoolConfig( host=mysql_service.host, port=mysql_service.port, user=mysql_service.user, @@ -37,7 +38,7 @@ async def test_async_connection(mysql_service: MySQLService) -> None: minsize=1, maxsize=5, ) - another_config = Asyncmy(pool_config=pool_config) + another_config = AsyncmyConfig(pool_config=pool_config) pool = await another_config.create_pool() assert pool is not None try: diff --git a/tests/integration/test_adapters/test_asyncmy/test_driver.py b/tests/integration/test_adapters/test_asyncmy/test_driver.py index 0fffd87bd..77a6a3921 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_driver.py +++ b/tests/integration/test_adapters/test_asyncmy/test_driver.py @@ -7,15 +7,15 @@ import pytest from pytest_databases.docker.mysql import MySQLService -from sqlspec.adapters.asyncmy import Asyncmy, AsyncmyPool +from sqlspec.adapters.asyncmy import AsyncmyConfig, AsyncmyPoolConfig ParamStyle = Literal["tuple_binds", "dict_binds"] pytestmark = pytest.mark.asyncio(loop_scope="session") -@pytest.fixture(scope="session") -def asyncmy_session(mysql_service: MySQLService) -> Asyncmy: +@pytest.fixture +def asyncmy_session(mysql_service: MySQLService) -> AsyncmyConfig: """Create an Asyncmy asynchronous session. Args: @@ -24,8 +24,8 @@ def asyncmy_session(mysql_service: MySQLService) -> Asyncmy: Returns: Configured Asyncmy asynchronous session. """ - return Asyncmy( - pool_config=AsyncmyPool( + return AsyncmyConfig( + pool_config=AsyncmyPoolConfig( host=mysql_service.host, port=mysql_service.port, user=mysql_service.user, @@ -43,7 +43,8 @@ def asyncmy_session(mysql_service: MySQLService) -> Asyncmy: ], ) @pytest.mark.xfail(reason="MySQL/Asyncmy does not support RETURNING clause directly") -async def test_async_insert_returning(asyncmy_session: Asyncmy, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("mysql") +async def test_async_insert_returning(asyncmy_session: AsyncmyConfig, params: Any, style: ParamStyle) -> None: """Test async insert returning functionality with different parameter styles.""" async with asyncmy_session.provide_session() as driver: # Manual cleanup at start of test @@ -82,7 +83,8 @@ async def test_async_insert_returning(asyncmy_session: Asyncmy, params: Any, sty pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select(asyncmy_session: Asyncmy, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("mysql") +async def test_async_select(asyncmy_session: AsyncmyConfig, params: Any, style: ParamStyle) -> None: """Test async select functionality with different parameter styles.""" async with asyncmy_session.provide_session() as driver: # Manual cleanup at start of test @@ -127,7 +129,8 @@ async def test_async_select(asyncmy_session: Asyncmy, params: Any, style: ParamS pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select_value(asyncmy_session: Asyncmy, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("mysql") +async def test_async_select_value(asyncmy_session: AsyncmyConfig, params: Any, style: ParamStyle) -> None: """Test async select_value functionality with different parameter styles.""" async with asyncmy_session.provide_session() as driver: # Manual cleanup at start of test @@ -163,7 +166,8 @@ async def test_async_select_value(asyncmy_session: Asyncmy, params: Any, style: assert value == "test_name" -async def test_insert(asyncmy_session: Asyncmy) -> None: +@pytest.mark.xdist_group("mysql") +async def test_insert(asyncmy_session: AsyncmyConfig) -> None: """Test inserting data.""" async with asyncmy_session.provide_session() as driver: # Manual cleanup at start of test @@ -185,7 +189,8 @@ async def test_insert(asyncmy_session: Asyncmy) -> None: assert row_count == 1 -async def test_select(asyncmy_session: Asyncmy) -> None: +@pytest.mark.xdist_group("mysql") +async def test_select(asyncmy_session: AsyncmyConfig) -> None: """Test selecting data.""" async with asyncmy_session.provide_session() as driver: # Manual cleanup at start of test diff --git a/tests/integration/test_adapters/test_duckdb/__init__.py b/tests/integration/test_adapters/test_duckdb/__init__.py index f1305e919..1084a2a71 100644 --- a/tests/integration/test_adapters/test_duckdb/__init__.py +++ b/tests/integration/test_adapters/test_duckdb/__init__.py @@ -1 +1,5 @@ """Integration tests for sqlspec adapters.""" + +import pytest + +pytestmark = [pytest.mark.duckdb, pytest.mark.duckdb_driver] diff --git a/tests/integration/test_adapters/test_duckdb/test_connection.py b/tests/integration/test_adapters/test_duckdb/test_connection.py index cbea0ffc5..727607e0f 100644 --- a/tests/integration/test_adapters/test_duckdb/test_connection.py +++ b/tests/integration/test_adapters/test_duckdb/test_connection.py @@ -1,12 +1,15 @@ """Test DuckDB connection configuration.""" -from sqlspec.adapters.duckdb.config import DuckDB +import pytest +from sqlspec.adapters.duckdb.config import DuckDBConfig + +@pytest.mark.xdist_group("duckdb") def test_connection() -> None: """Test connection components.""" # Test direct connection - config = DuckDB(database=":memory:") + config = DuckDBConfig(database=":memory:") with config.provide_connection() as conn: assert conn is not None diff --git a/tests/integration/test_adapters/test_duckdb/test_driver.py b/tests/integration/test_adapters/test_duckdb/test_driver.py index d348f3cab..38b31aa48 100644 --- a/tests/integration/test_adapters/test_duckdb/test_driver.py +++ b/tests/integration/test_adapters/test_duckdb/test_driver.py @@ -5,22 +5,23 @@ from collections.abc import Generator from typing import Any, Literal +import pyarrow as pa # Add pyarrow import import pytest -from sqlspec.adapters.duckdb import DuckDB, DuckDBDriver +from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver from tests.fixtures.sql_utils import create_tuple_or_dict_params, format_placeholder, format_sql ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") +@pytest.fixture def duckdb_session() -> Generator[DuckDBDriver, None, None]: """Create a DuckDB session with a test table. Returns: A DuckDB session with a test table. """ - adapter = DuckDB() + adapter = DuckDBConfig() with adapter.provide_session() as session: session.execute_script("CREATE SEQUENCE IF NOT EXISTS test_id_seq START 1", None) create_table_sql = """ @@ -36,12 +37,6 @@ def duckdb_session() -> Generator[DuckDBDriver, None, None]: session.execute_script("DROP SEQUENCE IF EXISTS test_id_seq", None) -@pytest.fixture(autouse=True) -def cleanup_table(duckdb_session: DuckDBDriver) -> None: - """Clean up the test table before each test.""" - duckdb_session.execute_script("DELETE FROM test_table", None) - - @pytest.mark.parametrize( ("params", "style"), [ @@ -49,6 +44,7 @@ def cleanup_table(duckdb_session: DuckDBDriver) -> None: pytest.param([{"name": "test_name", "id": 1}], "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("duckdb") def test_insert(duckdb_session: DuckDBDriver, params: list[Any], style: ParamStyle) -> None: """Test inserting data with different parameter styles.""" # DuckDB supports multiple inserts at once @@ -68,6 +64,7 @@ def test_insert(duckdb_session: DuckDBDriver, params: list[Any], style: ParamSty assert len(results) == 1 assert results[0]["name"] == "test_name" assert results[0]["id"] == 1 + duckdb_session.execute_script("DELETE FROM test_table", None) @pytest.mark.parametrize( @@ -77,6 +74,7 @@ def test_insert(duckdb_session: DuckDBDriver, params: list[Any], style: ParamSty pytest.param([{"name": "test_name", "id": 1}], "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("duckdb") def test_select(duckdb_session: DuckDBDriver, params: list[Any], style: ParamStyle) -> None: """Test selecting data with different parameter styles.""" # Insert test record @@ -105,6 +103,7 @@ def test_select(duckdb_session: DuckDBDriver, params: list[Any], style: ParamSty result = duckdb_session.select_one(select_where_sql, select_params) assert result is not None assert result["id"] == 1 + duckdb_session.execute_script("DELETE FROM test_table", None) @pytest.mark.parametrize( @@ -114,6 +113,7 @@ def test_select(duckdb_session: DuckDBDriver, params: list[Any], style: ParamSty pytest.param([{"name": "test_name", "id": 1}], "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("duckdb") def test_select_value(duckdb_session: DuckDBDriver, params: list[Any], style: ParamStyle) -> None: """Test select_value with different parameter styles.""" # Insert test record @@ -133,3 +133,37 @@ def test_select_value(duckdb_session: DuckDBDriver, params: list[Any], style: Pa value_params = create_tuple_or_dict_params([1], ["id"], style) value = duckdb_session.select_value(value_sql, value_params) assert value == "test_name" + duckdb_session.execute_script("DELETE FROM test_table", None) + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param([("arrow_name", 1)], "tuple_binds", id="tuple_binds"), + pytest.param([{"name": "arrow_name", "id": 1}], "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("duckdb") +def test_select_arrow(duckdb_session: DuckDBDriver, params: list[Any], style: ParamStyle) -> None: + """Test selecting data as an Arrow Table.""" + # Insert test record + sql_template = """ + INSERT INTO test_table (name, id) + VALUES ({}, {}) + """ + sql = format_sql(sql_template, ["name", "id"], style, "duckdb") + param = params[0] + duckdb_session.insert_update_delete(sql, param) + + # Test select_arrow + select_sql = "SELECT name, id FROM test_table WHERE id = 1" + empty_params = create_tuple_or_dict_params([], [], style) # DuckDB doesn't need params for this simple query + arrow_table = duckdb_session.select_arrow(select_sql, empty_params) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + assert arrow_table.column_names == ["name", "id"] + assert arrow_table.column("name").to_pylist() == ["arrow_name"] + assert arrow_table.column("id").to_pylist() == [1] + duckdb_session.execute_script("DELETE FROM test_table", None) diff --git a/tests/integration/test_adapters/test_oracledb/__init__.py b/tests/integration/test_adapters/test_oracledb/__init__.py index 1761b6f6d..d0a289b8d 100644 --- a/tests/integration/test_adapters/test_oracledb/__init__.py +++ b/tests/integration/test_adapters/test_oracledb/__init__.py @@ -1 +1,5 @@ """OracleDB adapter integration tests.""" + +import pytest + +pytestmark = [pytest.mark.oracle, pytest.mark.oracledb] diff --git a/tests/integration/test_adapters/test_oracledb/test_connection.py b/tests/integration/test_adapters/test_oracledb/test_connection.py index 69e3bf497..1169e1c4d 100644 --- a/tests/integration/test_adapters/test_oracledb/test_connection.py +++ b/tests/integration/test_adapters/test_oracledb/test_connection.py @@ -5,15 +5,14 @@ import pytest from pytest_databases.docker.oracle import OracleService -from sqlspec.adapters.oracledb import OracleAsync, OracleAsyncPool, OracleSync, OracleSyncPool - -pytestmark = pytest.mark.asyncio(loop_scope="session") +from sqlspec.adapters.oracledb import OracleAsyncConfig, OracleAsyncPoolConfig, OracleSyncConfig, OracleSyncPoolConfig +@pytest.mark.xdist_group("oracle") async def test_async_connection(oracle_23ai_service: OracleService) -> None: """Test async connection components for OracleDB.""" - async_config = OracleAsync( - pool_config=OracleAsyncPool( + async_config = OracleAsyncConfig( + pool_config=OracleAsyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, @@ -37,14 +36,14 @@ async def test_async_connection(oracle_23ai_service: OracleService) -> None: await pool.close() # Test pool re-creation and connection acquisition - pool_config = OracleAsyncPool( + pool_config = OracleAsyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, user=oracle_23ai_service.user, password=oracle_23ai_service.password, ) - another_config = OracleAsync(pool_config=pool_config) + another_config = OracleAsyncConfig(pool_config=pool_config) pool = await another_config.create_pool() assert pool is not None try: @@ -58,10 +57,11 @@ async def test_async_connection(oracle_23ai_service: OracleService) -> None: await pool.close() +@pytest.mark.xdist_group("oracle") def test_sync_connection(oracle_23ai_service: OracleService) -> None: """Test sync connection components for OracleDB.""" - sync_config = OracleSync( - pool_config=OracleSyncPool( + sync_config = OracleSyncConfig( + pool_config=OracleSyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, @@ -85,14 +85,14 @@ def test_sync_connection(oracle_23ai_service: OracleService) -> None: pool.close() # Test pool re-creation and connection acquisition - pool_config = OracleSyncPool( + pool_config = OracleSyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, user=oracle_23ai_service.user, password=oracle_23ai_service.password, ) - another_config = OracleSync(pool_config=pool_config) + another_config = OracleSyncConfig(pool_config=pool_config) pool = another_config.create_pool() assert pool is not None try: diff --git a/tests/integration/test_adapters/test_oracledb/test_driver_async.py b/tests/integration/test_adapters/test_oracledb/test_driver_async.py index e73848f46..881217c4c 100644 --- a/tests/integration/test_adapters/test_oracledb/test_driver_async.py +++ b/tests/integration/test_adapters/test_oracledb/test_driver_async.py @@ -2,13 +2,13 @@ from __future__ import annotations -from collections.abc import AsyncGenerator from typing import Any, Literal +import pyarrow as pa import pytest from pytest_databases.docker.oracle import OracleService -from sqlspec.adapters.oracledb import OracleAsync, OracleAsyncPool +from sqlspec.adapters.oracledb import OracleAsyncConfig, OracleAsyncPoolConfig ParamStyle = Literal["positional_binds", "dict_binds"] @@ -17,11 +17,11 @@ # --- Async Fixtures --- -@pytest.fixture(scope="session") -def oracle_async_session(oracle_23ai_service: OracleService) -> OracleAsync: +@pytest.fixture +def oracle_async_session(oracle_23ai_service: OracleService) -> OracleAsyncConfig: """Create an Oracle asynchronous session.""" - return OracleAsync( - pool_config=OracleAsyncPool( + return OracleAsyncConfig( + pool_config=OracleAsyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, @@ -31,21 +31,6 @@ def oracle_async_session(oracle_23ai_service: OracleService) -> OracleAsync: ) -@pytest.fixture(scope="session") -async def cleanup_async_table(oracle_async_session: OracleAsync) -> AsyncGenerator[None, None]: - """Clean up the test table before/after each async test. (Now mainly for end-of-session)""" - # Code before yield runs once before all session tests. - yield - # Code after yield runs once after all session tests. - try: - async with oracle_async_session.provide_session() as driver: - await driver.execute_script( - "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" - ) - except Exception: - pass - - # --- Async Tests --- @@ -59,7 +44,8 @@ async def cleanup_async_table(oracle_async_session: OracleAsync) -> AsyncGenerat @pytest.mark.skip( reason="Oracle does not support RETURNING multiple columns directly in the required syntax for this method." ) -async def test_async_insert_returning(oracle_async_session: OracleAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +async def test_async_insert_returning(oracle_async_session: OracleAsyncConfig, params: Any, style: ParamStyle) -> None: """Test async insert returning functionality with Oracle parameter styles.""" async with oracle_async_session.provide_session() as driver: # Manual cleanup at start of test @@ -87,6 +73,9 @@ async def test_async_insert_returning(oracle_async_session: OracleAsync, params: assert result["NAME"] == "test_name" assert result["ID"] is not None assert isinstance(result["ID"], int) + await driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) @pytest.mark.parametrize( @@ -96,7 +85,8 @@ async def test_async_insert_returning(oracle_async_session: OracleAsync, params: pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select(oracle_async_session: OracleAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +async def test_async_select(oracle_async_session: OracleAsyncConfig, params: Any, style: ParamStyle) -> None: """Test async select functionality with Oracle parameter styles.""" async with oracle_async_session.provide_session() as driver: # Manual cleanup at start of test @@ -128,6 +118,9 @@ async def test_async_select(oracle_async_session: OracleAsync, params: Any, styl results = await driver.select(select_sql, select_params) assert len(results) == 1 assert results[0]["NAME"] == "test_name" + await driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) @pytest.mark.parametrize( @@ -137,7 +130,8 @@ async def test_async_select(oracle_async_session: OracleAsync, params: Any, styl pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select_value(oracle_async_session: OracleAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +async def test_async_select_value(oracle_async_session: OracleAsyncConfig, params: Any, style: ParamStyle) -> None: """Test async select_value functionality with Oracle parameter styles.""" async with oracle_async_session.provide_session() as driver: # Manual cleanup at start of test @@ -164,3 +158,44 @@ async def test_async_select_value(oracle_async_session: OracleAsync, params: Any select_sql = "SELECT 'test_value' FROM dual" value = await driver.select_value(select_sql) assert value == "test_value" + await driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) + + +@pytest.mark.xdist_group("oracle") +async def test_async_select_arrow(oracle_async_session: OracleAsyncConfig) -> None: + """Test asynchronous select_arrow functionality.""" + async with oracle_async_session.provide_session() as driver: + # Manual cleanup at start of test + await driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) + sql = """ + CREATE TABLE test_table ( + id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY PRIMARY KEY, + name VARCHAR2(50) + ) + """ + await driver.execute_script(sql) + + # Insert test record using positional binds + insert_sql = "INSERT INTO test_table (name) VALUES (:1)" + await driver.insert_update_delete(insert_sql, ("arrow_name",)) + + # Select and verify with select_arrow using positional binds + select_sql = "SELECT name, id FROM test_table WHERE name = :1" + arrow_table = await driver.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # Oracle returns uppercase column names by default + assert arrow_table.column_names == ["NAME", "ID"] + assert arrow_table.column("NAME").to_pylist() == ["arrow_name"] + # Check ID exists and is a number (exact value depends on IDENTITY) + assert arrow_table.column("ID").to_pylist()[0] is not None + assert isinstance(arrow_table.column("ID").to_pylist()[0], (int, float)) # Oracle NUMBER maps to float/Decimal + await driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) diff --git a/tests/integration/test_adapters/test_oracledb/test_driver_sync.py b/tests/integration/test_adapters/test_oracledb/test_driver_sync.py index c07025e30..9ae0ba7e9 100644 --- a/tests/integration/test_adapters/test_oracledb/test_driver_sync.py +++ b/tests/integration/test_adapters/test_oracledb/test_driver_sync.py @@ -4,21 +4,22 @@ from typing import Any, Literal +import pyarrow as pa import pytest from pytest_databases.docker.oracle import OracleService -from sqlspec.adapters.oracledb import OracleSync, OracleSyncPool +from sqlspec.adapters.oracledb import OracleSyncConfig, OracleSyncPoolConfig ParamStyle = Literal["positional_binds", "dict_binds"] # --- Sync Fixtures --- -@pytest.fixture(scope="session") -def oracle_sync_session(oracle_23ai_service: OracleService) -> OracleSync: +@pytest.fixture +def oracle_sync_session(oracle_23ai_service: OracleService) -> OracleSyncConfig: """Create an Oracle synchronous session.""" - return OracleSync( - pool_config=OracleSyncPool( + return OracleSyncConfig( + pool_config=OracleSyncPoolConfig( host=oracle_23ai_service.host, port=oracle_23ai_service.port, service_name=oracle_23ai_service.service_name, @@ -28,19 +29,6 @@ def oracle_sync_session(oracle_23ai_service: OracleService) -> OracleSync: ) -@pytest.fixture(autouse=True) -def cleanup_sync_table(oracle_sync_session: OracleSync) -> None: - """Clean up the test table after each sync test.""" - try: - with oracle_sync_session.provide_session() as driver: - # Use a block to handle potential ORA-00942: table or view does not exist - driver.execute_script( - "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" - ) - except Exception: - pass - - # --- Sync Tests --- @@ -54,7 +42,8 @@ def cleanup_sync_table(oracle_sync_session: OracleSync) -> None: @pytest.mark.skip( reason="Oracle does not support RETURNING multiple columns directly in the required syntax for this method." ) -def test_sync_insert_returning(oracle_sync_session: OracleSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +def test_sync_insert_returning(oracle_sync_session: OracleSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous insert returning functionality with Oracle parameter styles.""" with oracle_sync_session.provide_session() as driver: sql = """ @@ -79,6 +68,9 @@ def test_sync_insert_returning(oracle_sync_session: OracleSync, params: Any, sty assert result["NAME"] == "test_name" assert result["ID"] is not None assert isinstance(result["ID"], int) + driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) @pytest.mark.parametrize( @@ -88,7 +80,8 @@ def test_sync_insert_returning(oracle_sync_session: OracleSync, params: Any, sty pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -def test_sync_select(oracle_sync_session: OracleSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +def test_sync_select(oracle_sync_session: OracleSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous select functionality with Oracle parameter styles.""" with oracle_sync_session.provide_session() as driver: sql = """ @@ -116,6 +109,9 @@ def test_sync_select(oracle_sync_session: OracleSync, params: Any, style: ParamS results = driver.select(select_sql, select_params) assert len(results) == 1 assert results[0]["NAME"] == "test_name" + driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) @pytest.mark.parametrize( @@ -125,7 +121,8 @@ def test_sync_select(oracle_sync_session: OracleSync, params: Any, style: ParamS pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -def test_sync_select_value(oracle_sync_session: OracleSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("oracle") +def test_sync_select_value(oracle_sync_session: OracleSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous select_value functionality with Oracle parameter styles.""" with oracle_sync_session.provide_session() as driver: sql = """ @@ -149,3 +146,40 @@ def test_sync_select_value(oracle_sync_session: OracleSync, params: Any, style: select_sql = "SELECT 'test_value' FROM dual" value = driver.select_value(select_sql) assert value == "test_value" + driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) + + +@pytest.mark.xdist_group("oracle") +def test_sync_select_arrow(oracle_sync_session: OracleSyncConfig) -> None: + """Test synchronous select_arrow functionality.""" + with oracle_sync_session.provide_session() as driver: + sql = """ + CREATE TABLE test_table ( + id NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY PRIMARY KEY, + name VARCHAR2(50) + ) + """ + driver.execute_script(sql) + + # Insert test record using positional binds + insert_sql = "INSERT INTO test_table (name) VALUES (:1)" + driver.insert_update_delete(insert_sql, ("arrow_name",)) + + # Select and verify with select_arrow using positional binds + select_sql = "SELECT name, id FROM test_table WHERE name = :1" + arrow_table = driver.select_arrow(select_sql, ("arrow_name",)) + + assert isinstance(arrow_table, pa.Table) + assert arrow_table.num_rows == 1 + assert arrow_table.num_columns == 2 + # Oracle returns uppercase column names by default + assert arrow_table.column_names == ["NAME", "ID"] + assert arrow_table.column("NAME").to_pylist() == ["arrow_name"] + # Check ID exists and is a number (exact value depends on IDENTITY) + assert arrow_table.column("ID").to_pylist()[0] is not None + assert isinstance(arrow_table.column("ID").to_pylist()[0], (int, float)) # Oracle NUMBER maps to float/Decimal + driver.execute_script( + "BEGIN EXECUTE IMMEDIATE 'DROP TABLE test_table'; EXCEPTION WHEN OTHERS THEN IF SQLCODE != -942 THEN RAISE; END IF; END;" + ) diff --git a/tests/integration/test_adapters/test_psycopg/__init__.py b/tests/integration/test_adapters/test_psycopg/__init__.py index f1305e919..a1b06d192 100644 --- a/tests/integration/test_adapters/test_psycopg/__init__.py +++ b/tests/integration/test_adapters/test_psycopg/__init__.py @@ -1 +1,5 @@ """Integration tests for sqlspec adapters.""" + +import pytest + +pytestmark = [pytest.mark.postgres, pytest.mark.psycopg] diff --git a/tests/integration/test_adapters/test_psycopg/conftest.py b/tests/integration/test_adapters/test_psycopg/conftest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/integration/test_adapters/test_psycopg/test_connection.py b/tests/integration/test_adapters/test_psycopg/test_connection.py index 23d13f2e3..a928736d8 100644 --- a/tests/integration/test_adapters/test_psycopg/test_connection.py +++ b/tests/integration/test_adapters/test_psycopg/test_connection.py @@ -1,16 +1,20 @@ import pytest from pytest_databases.docker.postgres import PostgresService -from sqlspec.adapters.psycopg import PsycopgAsync, PsycopgAsyncPool, PsycopgSync, PsycopgSyncPool - -pytestmark = pytest.mark.asyncio(loop_scope="session") +from sqlspec.adapters.psycopg import ( + PsycopgAsyncConfig, + PsycopgAsyncPoolConfig, + PsycopgSyncConfig, + PsycopgSyncPoolConfig, +) +@pytest.mark.xdist_group("postgres") async def test_async_connection(postgres_service: PostgresService) -> None: """Test async connection components.""" # Test direct connection - async_config = PsycopgAsync( - pool_config=PsycopgAsyncPool( + async_config = PsycopgAsyncConfig( + pool_config=PsycopgAsyncPoolConfig( conninfo=f"host={postgres_service.host} port={postgres_service.port} user={postgres_service.user} password={postgres_service.password} dbname={postgres_service.database}", ), ) @@ -24,31 +28,28 @@ async def test_async_connection(postgres_service: PostgresService) -> None: assert result == (1,) # Test connection pool - pool_config = PsycopgAsyncPool( + pool_config = PsycopgAsyncPoolConfig( conninfo=f"host={postgres_service.host} port={postgres_service.port} user={postgres_service.user} password={postgres_service.password} dbname={postgres_service.database}", min_size=1, max_size=5, ) - another_config = PsycopgAsync(pool_config=pool_config) - pool = await another_config.create_pool() - assert pool is not None - try: - async with pool.connection() as conn: - assert conn is not None - # Test basic query - async with conn.cursor() as cur: - await cur.execute("SELECT 1") - result = await cur.fetchone() - assert result == (1,) - finally: - await pool.close() + another_config = PsycopgAsyncConfig(pool_config=pool_config) + # Remove explicit pool creation and manual context management + async with another_config.provide_connection() as conn: + assert conn is not None + # Test basic query + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + result = await cur.fetchone() + assert result == (1,) +@pytest.mark.xdist_group("postgres") def test_sync_connection(postgres_service: PostgresService) -> None: """Test sync connection components.""" # Test direct connection - sync_config = PsycopgSync( - pool_config=PsycopgSyncPool( + sync_config = PsycopgSyncConfig( + pool_config=PsycopgSyncPoolConfig( conninfo=f"host={postgres_service.host} port={postgres_service.port} user={postgres_service.user} password={postgres_service.password} dbname={postgres_service.database}", ), ) @@ -62,21 +63,17 @@ def test_sync_connection(postgres_service: PostgresService) -> None: assert result == (1,) # Test connection pool - pool_config = PsycopgSyncPool( + pool_config = PsycopgSyncPoolConfig( conninfo=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", min_size=1, max_size=5, ) - another_config = PsycopgSync(pool_config=pool_config) - pool = another_config.create_pool() - assert pool is not None - try: - with pool.connection() as conn: - assert conn is not None - # Test basic query - with conn.cursor() as cur: - cur.execute("SELECT 1") - result = cur.fetchone() - assert result == (1,) - finally: - pool.close() + another_config = PsycopgSyncConfig(pool_config=pool_config) + # Remove explicit pool creation and manual context management + with another_config.provide_connection() as conn: + assert conn is not None + # Test basic query + with conn.cursor() as cur: + cur.execute("SELECT 1") + result = cur.fetchone() + assert result == (1,) diff --git a/tests/integration/test_adapters/test_psycopg/test_driver.py b/tests/integration/test_adapters/test_psycopg/test_driver.py index 141347f8a..319f8d373 100644 --- a/tests/integration/test_adapters/test_psycopg/test_driver.py +++ b/tests/integration/test_adapters/test_psycopg/test_driver.py @@ -2,19 +2,23 @@ from __future__ import annotations -from collections.abc import AsyncGenerator from typing import Any, Literal import pytest from pytest_databases.docker.postgres import PostgresService -from sqlspec.adapters.psycopg import PsycopgAsync, PsycopgAsyncPool, PsycopgSync, PsycopgSyncPool +from sqlspec.adapters.psycopg import ( + PsycopgAsyncConfig, + PsycopgAsyncPoolConfig, + PsycopgSyncConfig, + PsycopgSyncPoolConfig, +) ParamStyle = Literal["tuple_binds", "dict_binds"] -@pytest.fixture(scope="session") -def psycopg_sync_session(postgres_service: PostgresService) -> PsycopgSync: +@pytest.fixture +def psycopg_sync_session(postgres_service: PostgresService) -> PsycopgSyncConfig: """Create a Psycopg synchronous session. Args: @@ -23,15 +27,15 @@ def psycopg_sync_session(postgres_service: PostgresService) -> PsycopgSync: Returns: Configured Psycopg synchronous session. """ - return PsycopgSync( - pool_config=PsycopgSyncPool( + return PsycopgSyncConfig( + pool_config=PsycopgSyncPoolConfig( conninfo=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" ) ) -@pytest.fixture(scope="session") -def psycopg_async_session(postgres_service: PostgresService) -> PsycopgAsync: +@pytest.fixture +def psycopg_async_session(postgres_service: PostgresService) -> PsycopgAsyncConfig: """Create a Psycopg asynchronous session. Args: @@ -40,35 +44,13 @@ def psycopg_async_session(postgres_service: PostgresService) -> PsycopgAsync: Returns: Configured Psycopg asynchronous session. """ - return PsycopgAsync( - pool_config=PsycopgAsyncPool( + return PsycopgAsyncConfig( + pool_config=PsycopgAsyncPoolConfig( conninfo=f"host={postgres_service.host} port={postgres_service.port} user={postgres_service.user} password={postgres_service.password} dbname={postgres_service.database}" ) ) -@pytest.fixture(autouse=True) -async def cleanup_test_table(psycopg_async_session: PsycopgAsync) -> AsyncGenerator[None, None]: - """Clean up the test table after each test.""" - yield - async with psycopg_async_session.provide_session() as driver: - await driver.execute_script("DROP TABLE IF EXISTS test_table") - - -@pytest.fixture(autouse=True) -def cleanup_sync_table(psycopg_sync_session: PsycopgSync) -> None: - """Clean up the test table after each test.""" - with psycopg_sync_session.provide_session() as driver: - driver.execute_script("DROP TABLE IF EXISTS test_table") - - -@pytest.fixture(autouse=True) -async def cleanup_async_table(psycopg_async_session: PsycopgAsync) -> None: - """Clean up the test table after each test.""" - async with psycopg_async_session.provide_session() as driver: - await driver.execute_script("DROP TABLE IF EXISTS test_table") - - @pytest.mark.parametrize( ("params", "style"), [ @@ -76,7 +58,8 @@ async def cleanup_async_table(psycopg_async_session: PsycopgAsync) -> None: pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -def test_sync_insert_returning(psycopg_sync_session: PsycopgSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +def test_sync_insert_returning(psycopg_sync_session: PsycopgSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous insert returning functionality with different parameter styles.""" with psycopg_sync_session.provide_session() as driver: sql = """ @@ -105,6 +88,7 @@ def test_sync_insert_returning(psycopg_sync_session: PsycopgSync, params: Any, s assert result is not None assert result["name"] == "test_name" assert result["id"] is not None + driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -114,7 +98,8 @@ def test_sync_insert_returning(psycopg_sync_session: PsycopgSync, params: Any, s pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -def test_sync_select(psycopg_sync_session: PsycopgSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +def test_sync_select(psycopg_sync_session: PsycopgSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous select functionality with different parameter styles.""" with psycopg_sync_session.provide_session() as driver: # Create test table @@ -151,6 +136,7 @@ def test_sync_select(psycopg_sync_session: PsycopgSync, params: Any, style: Para results = driver.select(select_sql, params) assert len(results) == 1 assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -160,7 +146,8 @@ def test_sync_select(psycopg_sync_session: PsycopgSync, params: Any, style: Para pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -def test_sync_select_value(psycopg_sync_session: PsycopgSync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +def test_sync_select_value(psycopg_sync_session: PsycopgSyncConfig, params: Any, style: ParamStyle) -> None: """Test synchronous select_value functionality with different parameter styles.""" with psycopg_sync_session.provide_session() as driver: # Create test table @@ -190,6 +177,7 @@ def test_sync_select_value(psycopg_sync_session: PsycopgSync, params: Any, style # Don't pass parameters with a literal query that has no placeholders value = driver.select_value(select_sql) assert value == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -199,7 +187,11 @@ def test_sync_select_value(psycopg_sync_session: PsycopgSync, params: Any, style pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_insert_returning(psycopg_async_session: PsycopgAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_insert_returning( + psycopg_async_session: PsycopgAsyncConfig, params: Any, style: ParamStyle +) -> None: """Test async insert returning functionality with different parameter styles.""" async with psycopg_async_session.provide_session() as driver: sql = """ @@ -228,6 +220,7 @@ async def test_async_insert_returning(psycopg_async_session: PsycopgAsync, param assert result is not None assert result["name"] == "test_name" assert result["id"] is not None + await driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -237,7 +230,9 @@ async def test_async_insert_returning(psycopg_async_session: PsycopgAsync, param pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select(psycopg_async_session: PsycopgAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select(psycopg_async_session: PsycopgAsyncConfig, params: Any, style: ParamStyle) -> None: """Test async select functionality with different parameter styles.""" async with psycopg_async_session.provide_session() as driver: # Create test table @@ -274,6 +269,7 @@ async def test_async_select(psycopg_async_session: PsycopgAsync, params: Any, st results = await driver.select(select_sql, params) assert len(results) == 1 assert results[0]["name"] == "test_name" + await driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -283,7 +279,9 @@ async def test_async_select(psycopg_async_session: PsycopgAsync, params: Any, st pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -async def test_async_select_value(psycopg_async_session: PsycopgAsync, params: Any, style: ParamStyle) -> None: +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select_value(psycopg_async_session: PsycopgAsyncConfig, params: Any, style: ParamStyle) -> None: """Test async select_value functionality with different parameter styles.""" async with psycopg_async_session.provide_session() as driver: # Create test table @@ -318,9 +316,12 @@ async def test_async_select_value(psycopg_async_session: PsycopgAsync, params: A # Don't pass parameters with a literal query that has no placeholders value = await driver.select_value(select_sql) assert value == "test_name" + await driver.execute_script("DROP TABLE IF EXISTS test_table") -async def test_insert(psycopg_async_session: PsycopgAsync) -> None: +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_insert(psycopg_async_session: PsycopgAsyncConfig) -> None: """Test inserting data.""" async with psycopg_async_session.provide_session() as driver: sql = """ @@ -334,9 +335,12 @@ async def test_insert(psycopg_async_session: PsycopgAsync) -> None: insert_sql = "INSERT INTO test_table (name) VALUES (%s)" row_count = await driver.insert_update_delete(insert_sql, ("test",)) assert row_count == 1 + await driver.execute_script("DROP TABLE IF EXISTS test_table") -async def test_select(psycopg_async_session: PsycopgAsync) -> None: +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_select(psycopg_async_session: PsycopgAsyncConfig) -> None: """Test selecting data.""" async with psycopg_async_session.provide_session() as driver: # Create and populate test table @@ -356,6 +360,7 @@ async def test_select(psycopg_async_session: PsycopgAsync) -> None: results = await driver.select(select_sql) assert len(results) == 1 assert results[0]["name"] == "test" + await driver.execute_script("DROP TABLE IF EXISTS test_table") @pytest.mark.parametrize( @@ -366,7 +371,8 @@ async def test_select(psycopg_async_session: PsycopgAsync) -> None: "pyformat", ], ) -def test_param_styles(psycopg_sync_session: PsycopgSync, param_style: str) -> None: +@pytest.mark.xdist_group("postgres") +def test_param_styles(psycopg_sync_session: PsycopgSyncConfig, param_style: str) -> None: """Test different parameter styles.""" with psycopg_sync_session.provide_session() as driver: # Create test table @@ -399,3 +405,4 @@ def test_param_styles(psycopg_sync_session: PsycopgSync, param_style: str) -> No results = driver.select(select_sql) assert len(results) == 1 assert results[0]["name"] == "test" + driver.execute_script("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/test_adapters/test_sqlite/__init__.py b/tests/integration/test_adapters/test_sqlite/__init__.py index f1305e919..624095cbb 100644 --- a/tests/integration/test_adapters/test_sqlite/__init__.py +++ b/tests/integration/test_adapters/test_sqlite/__init__.py @@ -1 +1,5 @@ """Integration tests for sqlspec adapters.""" + +import pytest + +pytestmark = pytest.mark.sqlite diff --git a/tests/integration/test_adapters/test_sqlite/test_connection.py b/tests/integration/test_adapters/test_sqlite/test_connection.py index 949f09597..5fe2b0050 100644 --- a/tests/integration/test_adapters/test_sqlite/test_connection.py +++ b/tests/integration/test_adapters/test_sqlite/test_connection.py @@ -1,12 +1,15 @@ """Test SQLite connection configuration.""" -from sqlspec.adapters.sqlite.config import Sqlite +import pytest +from sqlspec.adapters.sqlite.config import SqliteConfig + +@pytest.mark.xdist_group("sqlite") def test_connection() -> None: """Test connection components.""" # Test direct connection - config = Sqlite(database=":memory:") + config = SqliteConfig(database=":memory:") with config.provide_connection() as conn: assert conn is not None diff --git a/tests/integration/test_adapters/test_sqlite/test_driver.py b/tests/integration/test_adapters/test_sqlite/test_driver.py index f6c7b0530..bc73297df 100644 --- a/tests/integration/test_adapters/test_sqlite/test_driver.py +++ b/tests/integration/test_adapters/test_sqlite/test_driver.py @@ -8,7 +8,7 @@ import pytest -from sqlspec.adapters.sqlite import Sqlite, SqliteDriver +from sqlspec.adapters.sqlite import SqliteConfig, SqliteDriver from tests.fixtures.sql_utils import create_tuple_or_dict_params, format_placeholder ParamStyle = Literal["tuple_binds", "dict_binds"] @@ -21,7 +21,7 @@ def sqlite_session() -> Generator[SqliteDriver, None, None]: Returns: A configured SQLite session with a test table. """ - adapter = Sqlite() + adapter = SqliteConfig() create_table_sql = """ CREATE TABLE IF NOT EXISTS test_table ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -48,6 +48,7 @@ def cleanup_table(sqlite_session: SqliteDriver) -> None: pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") def test_insert_update_delete_returning(sqlite_session: SqliteDriver, params: Any, style: ParamStyle) -> None: """Test insert_update_delete_returning with different parameter styles.""" # Check SQLite version for RETURNING support (3.35.0+) @@ -89,6 +90,7 @@ def test_insert_update_delete_returning(sqlite_session: SqliteDriver, params: An pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") def test_select(sqlite_session: SqliteDriver, params: Any, style: ParamStyle) -> None: """Test select functionality with different parameter styles.""" # Insert test record @@ -114,6 +116,7 @@ def test_select(sqlite_session: SqliteDriver, params: Any, style: ParamStyle) -> pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") def test_select_one(sqlite_session: SqliteDriver, params: Any, style: ParamStyle) -> None: """Test select_one functionality with different parameter styles.""" # Insert test record @@ -144,6 +147,7 @@ def test_select_one(sqlite_session: SqliteDriver, params: Any, style: ParamStyle pytest.param({"name": "test_name"}, {"id": 1}, "dict_binds", id="dict_binds"), ], ) +@pytest.mark.xdist_group("sqlite") def test_select_value( sqlite_session: SqliteDriver, name_params: Any, diff --git a/tests/unit/test_adapters/test_adbc/test_config.py b/tests/unit/test_adapters/test_adbc/test_config.py index 25cb64414..16ba1b86c 100644 --- a/tests/unit/test_adapters/test_adbc/test_config.py +++ b/tests/unit/test_adapters/test_adbc/test_config.py @@ -9,13 +9,13 @@ import pytest from adbc_driver_manager.dbapi import Connection -from sqlspec.adapters.adbc import Adbc +from sqlspec.adapters.adbc import AdbcConfig if TYPE_CHECKING: from collections.abc import Generator -class MockAdbc(Adbc): +class MockAdbc(AdbcConfig): """Mock implementation of ADBC for testing.""" def __init__(self, mock_connection: MagicMock | None = None, **kwargs: Any) -> None: @@ -50,13 +50,13 @@ def mock_adbc_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for ADBC.""" - config = Adbc() + config = AdbcConfig() assert config.connection_config_dict == {} # pyright: ignore def test_with_all_values() -> None: """Test ADBC with all values set.""" - config = Adbc( + config = AdbcConfig( uri="localhost", driver_name="test_driver", db_kwargs={"user": "test_user", "password": "test_pass", "database": "test_db"}, @@ -64,7 +64,6 @@ def test_with_all_values() -> None: assert config.connection_config_dict == { "uri": "localhost", - "driver": "test_driver", "user": "test_user", "password": "test_pass", "database": "test_db", @@ -73,14 +72,13 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - config = Adbc( + config = AdbcConfig( uri="localhost", driver_name="test_driver", db_kwargs={"user": "test_user", "password": "test_pass", "database": "test_db"}, ) config_dict = config.connection_config_dict assert config_dict["uri"] == "localhost" - assert config_dict["driver"] == "test_driver" assert config_dict["user"] == "test_user" assert config_dict["password"] == "test_pass" assert config_dict["database"] == "test_db" diff --git a/tests/unit/test_adapters/test_aiosqlite/test_config.py b/tests/unit/test_adapters/test_aiosqlite/test_config.py index 142d26767..bdae47871 100644 --- a/tests/unit/test_adapters/test_aiosqlite/test_config.py +++ b/tests/unit/test_adapters/test_aiosqlite/test_config.py @@ -9,7 +9,7 @@ import pytest from aiosqlite import Connection -from sqlspec.adapters.aiosqlite.config import Aiosqlite +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty @@ -27,7 +27,7 @@ def mock_aiosqlite_connection() -> Generator[MagicMock, None, None]: def test_minimal_config() -> None: """Test minimal configuration with only required values.""" - config = Aiosqlite() + config = AiosqliteConfig() assert config.database == ":memory:" assert config.timeout is Empty assert config.detect_types is Empty @@ -40,7 +40,7 @@ def test_minimal_config() -> None: def test_full_config() -> None: """Test configuration with all values set.""" - config = Aiosqlite( + config = AiosqliteConfig( database=":memory:", timeout=5.0, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, @@ -63,7 +63,7 @@ def test_full_config() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - config = Aiosqlite( + config = AiosqliteConfig( database=":memory:", timeout=5.0, detect_types=sqlite3.PARSE_DECLTYPES, @@ -82,7 +82,7 @@ def test_connection_config_dict() -> None: async def test_create_connection_success(mock_aiosqlite_connection: MagicMock) -> None: """Test successful connection creation.""" with patch("aiosqlite.connect", AsyncMock(return_value=mock_aiosqlite_connection)) as mock_connect: - config = Aiosqlite(database=":memory:") + config = AiosqliteConfig(database=":memory:") connection = await config.create_connection() assert connection is mock_aiosqlite_connection @@ -93,7 +93,7 @@ async def test_create_connection_success(mock_aiosqlite_connection: MagicMock) - async def test_create_connection_failure() -> None: """Test connection creation failure.""" with patch("aiosqlite.connect", AsyncMock(side_effect=Exception("Connection failed"))): - config = Aiosqlite(database=":memory:") + config = AiosqliteConfig(database=":memory:") with pytest.raises(ImproperConfigurationError, match="Could not configure the Aiosqlite connection"): await config.create_connection() @@ -102,7 +102,7 @@ async def test_create_connection_failure() -> None: async def test_provide_connection(mock_aiosqlite_connection: MagicMock) -> None: """Test provide_connection context manager.""" with patch("aiosqlite.connect", AsyncMock(return_value=mock_aiosqlite_connection)): - config = Aiosqlite(database=":memory:") + config = AiosqliteConfig(database=":memory:") async with config.provide_connection() as conn: assert conn is mock_aiosqlite_connection diff --git a/tests/unit/test_adapters/test_asyncmy/test_config.py b/tests/unit/test_adapters/test_asyncmy/test_config.py index bf02d8bae..6c9d218e7 100644 --- a/tests/unit/test_adapters/test_asyncmy/test_config.py +++ b/tests/unit/test_adapters/test_asyncmy/test_config.py @@ -8,14 +8,14 @@ import asyncmy # pyright: ignore import pytest -from sqlspec.adapters.asyncmy import Asyncmy, AsyncmyPool +from sqlspec.adapters.asyncmy import AsyncmyConfig, AsyncmyPoolConfig from sqlspec.exceptions import ImproperConfigurationError if TYPE_CHECKING: from collections.abc import Generator -class MockAsyncmy(Asyncmy): +class MockAsyncmy(AsyncmyConfig): """Mock implementation of Asyncmy for testing.""" async def create_connection(*args: Any, **kwargs: Any) -> asyncmy.Connection: # pyright: ignore @@ -29,7 +29,7 @@ def connection_config_dict(self) -> dict[str, Any]: return {} -class MockAsyncmyPool(AsyncmyPool): +class MockAsyncmyPool(AsyncmyPoolConfig): """Mock implementation of AsyncmyPool for testing.""" def __init__(self, host: str = "localhost", pool_instance: Any | None = None, **kwargs: Any) -> None: @@ -74,14 +74,14 @@ def mock_asyncmy_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for asyncmy.""" - config = Asyncmy() + config = AsyncmyConfig() assert config.pool_config is None assert config.pool_instance is None # pyright: ignore def test_with_all_values() -> None: """Test asyncmy with all values set.""" - pool_config = AsyncmyPool( + pool_config = AsyncmyPoolConfig( host="localhost", port=3306, user="test_user", @@ -90,7 +90,7 @@ def test_with_all_values() -> None: minsize=1, maxsize=10, ) - config = Asyncmy(pool_config=pool_config) + config = AsyncmyConfig(pool_config=pool_config) assert config.pool_config == pool_config assert config.pool_instance is None # pyright: ignore @@ -105,14 +105,14 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - pool_config = AsyncmyPool( + pool_config = AsyncmyPoolConfig( host="localhost", port=3306, user="test_user", password="test_pass", database="test_db", ) - config = Asyncmy(pool_config=pool_config) + config = AsyncmyConfig(pool_config=pool_config) config_dict = config.connection_config_dict assert config_dict["host"] == "localhost" assert config_dict["port"] == 3306 diff --git a/tests/unit/test_adapters/test_asyncpg/test_config.py b/tests/unit/test_adapters/test_asyncpg/test_config.py index b9f59c4e6..824a6b44a 100644 --- a/tests/unit/test_adapters/test_asyncpg/test_config.py +++ b/tests/unit/test_adapters/test_asyncpg/test_config.py @@ -8,14 +8,14 @@ import asyncpg import pytest -from sqlspec.adapters.asyncpg import Asyncpg, AsyncpgPool +from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgPoolConfig from sqlspec.exceptions import ImproperConfigurationError if TYPE_CHECKING: from collections.abc import Generator -class MockAsyncpg(Asyncpg): +class MockAsyncpg(AsyncpgConfig): """Mock implementation of Asyncpg for testing.""" async def create_connection(*args: Any, **kwargs: Any) -> asyncpg.Connection[Any]: @@ -29,7 +29,7 @@ def connection_config_dict(self) -> dict[str, Any]: return {} -class MockAsyncpgPool(AsyncpgPool): +class MockAsyncpgPool(AsyncpgPoolConfig): """Mock implementation of AsyncpgPool for testing.""" def __init__(self, dsn: str, pool_instance: Any | None = None, **kwargs: Any) -> None: @@ -74,21 +74,21 @@ def mock_asyncpg_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for Asyncpg.""" - config = Asyncpg() + config = AsyncpgConfig() assert config.pool_config is None assert config.pool_instance is None def test_with_all_values() -> None: """Test Asyncpg with all values set.""" - pool_config = AsyncpgPool( + pool_config = AsyncpgPoolConfig( dsn="postgres://test_user:test_pass@localhost:5432/test_db", min_size=1, max_size=10, max_inactive_connection_lifetime=300.0, max_queries=50000, ) - config = Asyncpg(pool_config=pool_config) + config = AsyncpgConfig(pool_config=pool_config) assert config.pool_config == pool_config assert config.pool_instance is None @@ -96,17 +96,17 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - pool_config = AsyncpgPool( + pool_config = AsyncpgPoolConfig( dsn="postgres://test_user:test_pass@localhost:5432/test_db", ) - config = Asyncpg(pool_config=pool_config) + config = AsyncpgConfig(pool_config=pool_config) config_dict = config.connection_config_dict assert config_dict["dsn"] == "postgres://test_user:test_pass@localhost:5432/test_db" def test_pool_config_dict_with_pool_config() -> None: """Test pool_config_dict with pool configuration.""" - pool_config = AsyncpgPool( + pool_config = AsyncpgPoolConfig( dsn="postgres://test_user:test_pass@localhost:5432/test_db", min_size=1, max_size=10, diff --git a/tests/unit/test_adapters/test_duckdb/test_config.py b/tests/unit/test_adapters/test_duckdb/test_config.py index 919da1362..ec9573686 100644 --- a/tests/unit/test_adapters/test_duckdb/test_config.py +++ b/tests/unit/test_adapters/test_duckdb/test_config.py @@ -8,7 +8,7 @@ import duckdb import pytest -from sqlspec.adapters.duckdb.config import DuckDB, ExtensionConfig, SecretConfig +from sqlspec.adapters.duckdb.config import DuckDBConfig, ExtensionConfig, SecretConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty @@ -16,7 +16,7 @@ from collections.abc import Generator -class MockDuckDB(DuckDB): +class MockDuckDB(DuckDBConfig): """Mock implementation of DuckDB for testing.""" def __init__(self, *args: Any, connection: MagicMock | None = None, **kwargs: Any) -> None: @@ -45,7 +45,7 @@ def mock_duckdb_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for DuckDB.""" - config = DuckDB() + config = DuckDBConfig() assert config.database == ":memory:" assert config.read_only is Empty assert config.config == {} @@ -66,7 +66,7 @@ def on_connection_create(conn: duckdb.DuckDBPyConnection) -> None: extensions: list[ExtensionConfig] = [{"name": "test_ext"}] secrets: list[SecretConfig] = [{"name": "test_secret", "secret_type": "s3", "value": {"key": "value"}}] - config = DuckDB( + config = DuckDBConfig( database="test.db", read_only=True, config={"setting": "value"}, @@ -91,7 +91,7 @@ def on_connection_create(conn: duckdb.DuckDBPyConnection) -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - config = DuckDB( + config = DuckDBConfig( database="test.db", read_only=True, config={"setting": "value"}, @@ -116,7 +116,7 @@ def test_create_connection() -> None: def test_create_connection_error() -> None: """Test create_connection method with error.""" - config = DuckDB( + config = DuckDBConfig( database="test.db", read_only=True, config={"setting": "value"}, diff --git a/tests/unit/test_adapters/test_oracledb/test_async_config.py b/tests/unit/test_adapters/test_oracledb/test_async_config.py index b507469e6..50a4172fd 100644 --- a/tests/unit/test_adapters/test_oracledb/test_async_config.py +++ b/tests/unit/test_adapters/test_oracledb/test_async_config.py @@ -8,14 +8,14 @@ import pytest from oracledb import AsyncConnection, AsyncConnectionPool -from sqlspec.adapters.oracledb import OracleAsync, OracleAsyncPool +from sqlspec.adapters.oracledb import OracleAsyncConfig, OracleAsyncPoolConfig from sqlspec.exceptions import ImproperConfigurationError if TYPE_CHECKING: from collections.abc import Generator -class MockOracleAsync(OracleAsync): +class MockOracleAsync(OracleAsyncConfig): """Mock implementation of OracleAsync for testing.""" async def create_connection(*args: Any, **kwargs: Any) -> AsyncConnection: @@ -53,7 +53,7 @@ def mock_oracle_async_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for OracleAsync.""" - config = OracleAsync() + config = OracleAsyncConfig() assert config.pool_config is None assert config.pool_instance is None @@ -61,10 +61,10 @@ def test_default_values() -> None: def test_with_all_values() -> None: """Test OracleAsync with all values set.""" mock_pool = MagicMock(spec=AsyncConnectionPool) - pool_config = OracleAsyncPool( + pool_config = OracleAsyncPoolConfig( pool=mock_pool, ) - config = OracleAsync( + config = OracleAsyncConfig( pool_config=pool_config, ) @@ -75,10 +75,10 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" mock_pool = MagicMock(spec=AsyncConnectionPool) - pool_config = OracleAsyncPool( + pool_config = OracleAsyncPoolConfig( pool=mock_pool, ) - config = OracleAsync( + config = OracleAsyncConfig( pool_config=pool_config, ) config_dict = config.connection_config_dict @@ -89,7 +89,7 @@ def test_connection_config_dict() -> None: def test_pool_config_dict_with_pool_config() -> None: """Test pool_config_dict with pool configuration.""" mock_pool = MagicMock(spec=AsyncConnectionPool) - pool_config = OracleAsyncPool( + pool_config = OracleAsyncPoolConfig( pool=mock_pool, ) config = MockOracleAsync(pool_config=pool_config) diff --git a/tests/unit/test_adapters/test_oracledb/test_sync_config.py b/tests/unit/test_adapters/test_oracledb/test_sync_config.py index 8afadc0f8..fe18254e9 100644 --- a/tests/unit/test_adapters/test_oracledb/test_sync_config.py +++ b/tests/unit/test_adapters/test_oracledb/test_sync_config.py @@ -8,14 +8,14 @@ import pytest from oracledb import Connection, ConnectionPool -from sqlspec.adapters.oracledb.config import OracleSync, OracleSyncPool +from sqlspec.adapters.oracledb.config import OracleSyncConfig, OracleSyncPoolConfig from sqlspec.exceptions import ImproperConfigurationError if TYPE_CHECKING: from collections.abc import Generator -class MockOracleSync(OracleSync): +class MockOracleSync(OracleSyncConfig): """Mock implementation of OracleSync for testing.""" def create_connection(*args: Any, **kwargs: Any) -> Connection: @@ -50,7 +50,7 @@ def mock_oracle_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for OracleSync.""" - config = OracleSync() + config = OracleSyncConfig() assert config.pool_config is None assert config.pool_instance is None @@ -58,10 +58,10 @@ def test_default_values() -> None: def test_with_all_values() -> None: """Test OracleSync with all values set.""" mock_pool = MagicMock(spec=ConnectionPool) - pool_config = OracleSyncPool( + pool_config = OracleSyncPoolConfig( pool=mock_pool, ) - config = OracleSync( + config = OracleSyncConfig( pool_config=pool_config, ) @@ -72,10 +72,10 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" mock_pool = MagicMock(spec=ConnectionPool) - pool_config = OracleSyncPool( + pool_config = OracleSyncPoolConfig( pool=mock_pool, ) - config = OracleSync( + config = OracleSyncConfig( pool_config=pool_config, ) config_dict = config.connection_config_dict @@ -86,7 +86,7 @@ def test_connection_config_dict() -> None: def test_pool_config_dict_with_pool_config() -> None: """Test pool_config_dict with pool configuration.""" mock_pool = MagicMock(spec=ConnectionPool) - pool_config = OracleSyncPool( + pool_config = OracleSyncPoolConfig( pool=mock_pool, ) config = MockOracleSync(pool_config=pool_config) diff --git a/tests/unit/test_adapters/test_psycopg/test_async_config.py b/tests/unit/test_adapters/test_psycopg/test_async_config.py index a300b2288..6b7f5e6a9 100644 --- a/tests/unit/test_adapters/test_psycopg/test_async_config.py +++ b/tests/unit/test_adapters/test_psycopg/test_async_config.py @@ -9,7 +9,7 @@ from psycopg import AsyncConnection from psycopg_pool import AsyncConnectionPool -from sqlspec.adapters.psycopg.config import PsycopgAsync, PsycopgAsyncPool +from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgAsyncPoolConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty @@ -17,7 +17,7 @@ from collections.abc import Generator -class MockPsycopgAsync(PsycopgAsync): +class MockPsycopgAsync(PsycopgAsyncConfig): """Mock implementation of PsycopgAsync for testing.""" async def create_connection(*args: Any, **kwargs: Any) -> AsyncConnection: @@ -58,7 +58,7 @@ def mock_psycopg_async_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for PsycopgAsyncPool.""" - config = PsycopgAsyncPool() + config = PsycopgAsyncPoolConfig() assert config.conninfo is Empty assert config.kwargs is Empty assert config.min_size is Empty @@ -80,7 +80,7 @@ def configure_connection(conn: AsyncConnection) -> None: """Configure connection.""" pass - config = PsycopgAsyncPool( + config = PsycopgAsyncPoolConfig( conninfo="postgresql://user:pass@localhost:5432/db", kwargs={"application_name": "test"}, min_size=1, @@ -111,7 +111,7 @@ def configure_connection(conn: AsyncConnection) -> None: def test_pool_config_dict_with_pool_config() -> None: """Test pool_config_dict with pool configuration.""" - pool_config = PsycopgAsyncPool( + pool_config = PsycopgAsyncPoolConfig( conninfo="postgresql://user:pass@localhost:5432/db", min_size=1, max_size=10, diff --git a/tests/unit/test_adapters/test_psycopg/test_sync_config.py b/tests/unit/test_adapters/test_psycopg/test_sync_config.py index fd1a4b3c5..9a9dc8207 100644 --- a/tests/unit/test_adapters/test_psycopg/test_sync_config.py +++ b/tests/unit/test_adapters/test_psycopg/test_sync_config.py @@ -9,7 +9,7 @@ from psycopg import Connection from psycopg_pool import ConnectionPool -from sqlspec.adapters.psycopg.config import PsycopgSync, PsycopgSyncPool +from sqlspec.adapters.psycopg.config import PsycopgSyncConfig, PsycopgSyncPoolConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty @@ -17,7 +17,7 @@ from collections.abc import Generator -class MockPsycopgSync(PsycopgSync): +class MockPsycopgSync(PsycopgSyncConfig): """Mock implementation of PsycopgSync for testing.""" def create_connection(*args: Any, **kwargs: Any) -> Connection: @@ -56,7 +56,7 @@ def mock_psycopg_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for PsycopgSyncPool.""" - config = PsycopgSyncPool() + config = PsycopgSyncPoolConfig() assert config.conninfo is Empty assert config.kwargs is Empty assert config.min_size is Empty @@ -78,7 +78,7 @@ def configure_connection(conn: Connection) -> None: """Configure connection.""" pass - config = PsycopgSyncPool( + config = PsycopgSyncPoolConfig( conninfo="postgresql://user:pass@localhost:5432/db", kwargs={"application_name": "test"}, min_size=1, @@ -109,7 +109,7 @@ def configure_connection(conn: Connection) -> None: def test_pool_config_dict_with_pool_config() -> None: """Test pool_config_dict with pool configuration.""" - pool_config = PsycopgSyncPool( + pool_config = PsycopgSyncPoolConfig( conninfo="postgresql://user:pass@localhost:5432/db", min_size=1, max_size=10, diff --git a/tests/unit/test_adapters/test_sqlite/test_config.py b/tests/unit/test_adapters/test_sqlite/test_config.py index bc0fb9c17..cc08a2e81 100644 --- a/tests/unit/test_adapters/test_sqlite/test_config.py +++ b/tests/unit/test_adapters/test_sqlite/test_config.py @@ -8,7 +8,7 @@ import pytest -from sqlspec.adapters.sqlite.config import Sqlite +from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty @@ -27,7 +27,7 @@ def mock_sqlite_connection() -> Generator[MagicMock, None, None]: def test_default_values() -> None: """Test default values for Sqlite.""" - config = Sqlite() + config = SqliteConfig() assert config.database == ":memory:" assert config.timeout is Empty assert config.detect_types is Empty @@ -40,7 +40,7 @@ def test_default_values() -> None: def test_with_all_values() -> None: """Test Sqlite with all values set.""" - config = Sqlite( + config = SqliteConfig( database="test.db", timeout=30.0, detect_types=1, @@ -62,14 +62,14 @@ def test_with_all_values() -> None: def test_connection_config_dict() -> None: """Test connection_config_dict property.""" - config = Sqlite(database="test.db", timeout=30.0) + config = SqliteConfig(database="test.db", timeout=30.0) config_dict = config.connection_config_dict assert config_dict == {"database": "test.db", "timeout": 30.0} def test_create_connection(mock_sqlite_connection: MagicMock) -> None: """Test create_connection method.""" - config = Sqlite(database="test.db") + config = SqliteConfig(database="test.db") connection = config.create_connection() assert connection is mock_sqlite_connection @@ -77,13 +77,13 @@ def test_create_connection(mock_sqlite_connection: MagicMock) -> None: def test_create_connection_error() -> None: """Test create_connection raises error on failure.""" with patch("sqlite3.connect", side_effect=Exception("Test error")): - config = Sqlite(database="test.db") + config = SqliteConfig(database="test.db") with pytest.raises(ImproperConfigurationError, match="Could not configure the SQLite connection"): config.create_connection() def test_provide_connection(mock_sqlite_connection: MagicMock) -> None: """Test provide_connection context manager.""" - config = Sqlite(database="test.db") + config = SqliteConfig(database="test.db") with config.provide_connection() as connection: assert connection is mock_sqlite_connection diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index 026615af0..11cf6d78f 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -72,6 +72,14 @@ def _provide_pool() -> Generator[MockPool, None, None]: return _provide_pool() + @contextmanager + def provide_session(self, *args: Any, **kwargs: Any) -> Generator[MockConnection, None, None]: + connection = self.create_connection() + try: + yield connection + finally: + connection.close() + class MockNonPoolConfig(NoPoolSyncConfig[MockConnection, Any]): """Mock database configuration that doesn't support pooling.""" @@ -90,6 +98,14 @@ def provide_connection(self, *args: Any, **kwargs: Any) -> Generator[MockConnect def close_pool(self) -> None: pass + @contextmanager + def provide_session(self, *args: Any, **kwargs: Any) -> Generator[MockConnection, None, None]: + connection = self.create_connection() + try: + yield connection + finally: + connection.close() + @property def connection_config_dict(self) -> dict[str, Any]: return {"host": "localhost", "port": 5432} @@ -112,6 +128,14 @@ async def provide_connection(self, *args: Any, **kwargs: Any) -> AsyncGenerator[ async def close_pool(self) -> None: pass + @asynccontextmanager + async def provide_session(self, *args: Any, **kwargs: Any) -> AsyncGenerator[MockAsyncConnection, None]: + connection = self.create_connection() + try: + yield connection + finally: + await connection.close() + @property def connection_config_dict(self) -> dict[str, Any]: return {"host": "localhost", "port": 5432}