diff --git a/sqlspec/adapters/adbc/__init__.py b/sqlspec/adapters/adbc/__init__.py index f7432e597..e7c0b90a1 100644 --- a/sqlspec/adapters/adbc/__init__.py +++ b/sqlspec/adapters/adbc/__init__.py @@ -1,7 +1,8 @@ from sqlspec.adapters.adbc.config import AdbcConfig -from sqlspec.adapters.adbc.driver import AdbcDriver +from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver __all__ = ( "AdbcConfig", + "AdbcConnection", "AdbcDriver", ) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 86fcb738b..36c69a6a8 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -2,9 +2,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast -from adbc_driver_manager.dbapi import Connection - -from sqlspec.adapters.adbc.driver import AdbcDriver +from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver from sqlspec.base import NoPoolSyncConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType @@ -18,7 +16,7 @@ @dataclass -class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]): +class AdbcConfig(NoPoolSyncConfig["AdbcConnection", "AdbcDriver"]): """Configuration for ADBC connections. This class provides configuration options for ADBC database connections using the @@ -33,20 +31,16 @@ class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]): """Additional database-specific connection parameters""" conn_kwargs: "Optional[dict[str, Any]]" = None """Additional database-specific connection parameters""" - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) + connection_type: "type[AdbcConnection]" = field(init=False, default_factory=lambda: AdbcConnection) """Type of the connection object""" driver_type: "type[AdbcDriver]" = field(init=False, default_factory=lambda: AdbcDriver) # type: ignore[type-abstract,unused-ignore] """Type of the driver object""" - pool_instance: None = field(init=False, default=None) + pool_instance: None = field(init=False, default=None, hash=False) """No connection pool is used for ADBC connections""" - _is_in_memory: bool = field(init=False, default=False) - """Flag indicating if the connection is for an in-memory database""" def _set_adbc(self) -> str: # noqa: PLR0912 """Identify the driver type based on the URI (if provided) or preset driver name. - Also sets the `_is_in_memory` flag for specific in-memory URIs. - Raises: ImproperConfigurationError: If the driver name is not recognized or supported. @@ -143,7 +137,7 @@ def connection_config_dict(self) -> "dict[str, Any]": config["conn_kwargs"] = conn_kwargs return config - def _get_connect_func(self) -> "Callable[..., Connection]": + def _get_connect_func(self) -> "Callable[..., AdbcConnection]": self._set_adbc() driver_path = cast("str", self.driver_name) try: @@ -166,7 +160,7 @@ def _get_connect_func(self) -> "Callable[..., Connection]": raise ImproperConfigurationError(msg) return connect_func # type: ignore[no-any-return] - def create_connection(self) -> "Connection": + def create_connection(self) -> "AdbcConnection": """Create and return a new database connection using the specific driver. Returns: @@ -189,7 +183,7 @@ def create_connection(self) -> "Connection": raise ImproperConfigurationError(msg) from e @contextmanager - def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]": + def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[AdbcConnection, None, None]": """Create and provide a database connection using the specific driver. Yields: diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 1562bea8c..8b55e1d6e 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T -__all__ = ("AdbcDriver",) +__all__ = ("AdbcConnection", "AdbcDriver") logger = logging.getLogger("sqlspec") @@ -33,24 +33,26 @@ re.VERBOSE | re.DOTALL, ) +AdbcConnection = Connection + class AdbcDriver( - SyncArrowBulkOperationsMixin["Connection"], - SQLTranslatorMixin["Connection"], - SyncDriverAdapterProtocol["Connection"], + SyncArrowBulkOperationsMixin["AdbcConnection"], + SQLTranslatorMixin["AdbcConnection"], + SyncDriverAdapterProtocol["AdbcConnection"], ): """ADBC Sync Driver Adapter.""" - connection: Connection + connection: AdbcConnection __supports_arrow__: ClassVar[bool] = True - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "AdbcConnection") -> None: """Initialize the ADBC driver adapter.""" self.connection = connection self.dialect = self._get_dialect(connection) @staticmethod - def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911 + def _get_dialect(connection: "AdbcConnection") -> str: # noqa: PLR0911 """Get the database dialect based on the driver name. Args: @@ -75,11 +77,11 @@ def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911 return "postgres" # default to postgresql dialect @staticmethod - def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor": + def _cursor(connection: "AdbcConnection", *args: Any, **kwargs: Any) -> "Cursor": return connection.cursor(*args, **kwargs) @contextmanager - def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]: + def _with_cursor(self, connection: "AdbcConnection") -> Generator["Cursor", None, None]: cursor = self._cursor(connection) try: yield cursor @@ -172,7 +174,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -183,7 +185,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -193,7 +195,7 @@ def select( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -223,7 +225,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -234,7 +236,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -244,7 +246,7 @@ def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -271,7 +273,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -282,7 +284,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -292,7 +294,7 @@ def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -320,7 +322,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -331,7 +333,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -341,7 +343,7 @@ def select_value( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -367,7 +369,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -378,7 +380,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -388,7 +390,7 @@ def select_value_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -414,7 +416,7 @@ def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -436,7 +438,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -447,7 +449,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -457,7 +459,7 @@ def insert_update_delete_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -490,7 +492,7 @@ def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AdbcConnection"] = None, **kwargs: Any, ) -> str: """Execute a script. @@ -513,7 +515,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AdbcConnection]" = None, **kwargs: Any, ) -> "ArrowTable": """Execute a SQL query and return results as an Apache Arrow Table. diff --git a/sqlspec/adapters/aiosqlite/__init__.py b/sqlspec/adapters/aiosqlite/__init__.py index 18c21a803..a959a45eb 100644 --- a/sqlspec/adapters/aiosqlite/__init__.py +++ b/sqlspec/adapters/aiosqlite/__init__.py @@ -1,7 +1,8 @@ from sqlspec.adapters.aiosqlite.config import AiosqliteConfig -from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver +from sqlspec.adapters.aiosqlite.driver import AiosqliteConnection, AiosqliteDriver __all__ = ( "AiosqliteConfig", + "AiosqliteConnection", "AiosqliteDriver", ) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 87512c768..b80eba628 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -2,16 +2,15 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional, Union -from aiosqlite import Connection +import aiosqlite -from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver +from sqlspec.adapters.aiosqlite.driver import AiosqliteConnection, AiosqliteDriver from sqlspec.base import NoPoolAsyncConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: from collections.abc import AsyncGenerator - from sqlite3 import Connection as SQLite3Connection from typing import Literal @@ -19,7 +18,7 @@ @dataclass -class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]): +class AiosqliteConfig(NoPoolAsyncConfig["AiosqliteConnection", "AiosqliteDriver"]): """Configuration for Aiosqlite database connections. This class provides configuration options for Aiosqlite database connections, wrapping all parameters @@ -38,13 +37,11 @@ class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]): """The isolation_level of the connection. This can be None for autocommit mode or one of "DEFERRED", "IMMEDIATE" or "EXCLUSIVE".""" check_same_thread: "Union[bool, EmptyType]" = field(default=Empty) """If True (default), ProgrammingError is raised if the database connection is used by a thread other than the one that created it. If False, the connection may be shared across multiple threads.""" - factory: "Union[type[SQLite3Connection], EmptyType]" = field(default=Empty) - """A custom Connection class factory. If given, must be a callable that returns a Connection instance.""" cached_statements: "Union[int, EmptyType]" = field(default=Empty) """The number of statements that SQLite will cache for this connection. The default is 128.""" uri: "Union[bool, EmptyType]" = field(default=Empty) """If set to True, database is interpreted as a URI with supported options.""" - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) + connection_type: "type[AiosqliteConnection]" = field(init=False, default_factory=lambda: AiosqliteConnection) """Type of the connection object""" driver_type: "type[AiosqliteDriver]" = field(init=False, default_factory=lambda: AiosqliteDriver) # type: ignore[type-abstract,unused-ignore] """Type of the driver object""" @@ -57,10 +54,13 @@ def connection_config_dict(self) -> "dict[str, Any]": A string keyed dict of config kwargs for the aiosqlite.connect() function. """ return dataclass_to_dict( - self, exclude_empty=True, convert_nested=False, exclude={"pool_instance", "connection_type", "driver_type"} + self, + exclude_empty=True, + convert_nested=False, + exclude={"pool_instance", "connection_type", "driver_type"}, ) - async def create_connection(self) -> "Connection": + async def create_connection(self) -> "AiosqliteConnection": """Create and return a new database connection. Returns: @@ -69,8 +69,6 @@ async def create_connection(self) -> "Connection": Raises: ImproperConfigurationError: If the connection could not be established. """ - import aiosqlite - try: return await aiosqlite.connect(**self.connection_config_dict) except Exception as e: @@ -78,7 +76,7 @@ async def create_connection(self) -> "Connection": raise ImproperConfigurationError(msg) from e @asynccontextmanager - async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[Connection, None]": + async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AiosqliteConnection, None]": """Create and provide a database connection. Yields: diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 860187505..fb6e34e52 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -1,37 +1,38 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +import aiosqlite + from sqlspec.base import AsyncDriverAdapterProtocol from sqlspec.mixins import SQLTranslatorMixin if TYPE_CHECKING: from collections.abc import AsyncGenerator, Sequence - from aiosqlite import Connection, Cursor - from sqlspec.typing import ModelDTOT, StatementParameterType, T -__all__ = ("AiosqliteDriver",) +__all__ = ("AiosqliteConnection", "AiosqliteDriver") +AiosqliteConnection = aiosqlite.Connection class AiosqliteDriver( - SQLTranslatorMixin["Connection"], - AsyncDriverAdapterProtocol["Connection"], + SQLTranslatorMixin["AiosqliteConnection"], + AsyncDriverAdapterProtocol["AiosqliteConnection"], ): """SQLite Async Driver Adapter.""" - connection: "Connection" + connection: "AiosqliteConnection" dialect: str = "sqlite" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "AiosqliteConnection") -> None: self.connection = connection @staticmethod - async def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor": + async def _cursor(connection: "AiosqliteConnection", *args: Any, **kwargs: Any) -> "aiosqlite.Cursor": return await connection.cursor(*args, **kwargs) @asynccontextmanager - async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor, None]": + async def _with_cursor(self, connection: "AiosqliteConnection") -> "AsyncGenerator[aiosqlite.Cursor, None]": cursor = await self._cursor(connection) try: yield cursor @@ -46,7 +47,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -57,7 +58,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -67,7 +68,7 @@ async def select( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -95,7 +96,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -106,7 +107,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -116,7 +117,7 @@ async def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -143,7 +144,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -154,7 +155,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -164,7 +165,7 @@ async def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -192,7 +193,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -203,7 +204,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -213,7 +214,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -239,7 +240,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -250,7 +251,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -260,7 +261,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -286,7 +287,7 @@ async def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -308,7 +309,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -319,7 +320,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -329,7 +330,7 @@ async def insert_update_delete_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -357,7 +358,7 @@ async def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, **kwargs: Any, ) -> str: """Execute a script. @@ -378,7 +379,7 @@ async def execute_script_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AiosqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": diff --git a/sqlspec/adapters/asyncmy/__init__.py b/sqlspec/adapters/asyncmy/__init__.py index 6774c059d..40d7b74f9 100644 --- a/sqlspec/adapters/asyncmy/__init__.py +++ b/sqlspec/adapters/asyncmy/__init__.py @@ -1,8 +1,9 @@ from sqlspec.adapters.asyncmy.config import AsyncmyConfig, AsyncmyPoolConfig -from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined] +from sqlspec.adapters.asyncmy.driver import AsyncmyConnection, AsyncmyDriver # type: ignore[attr-defined] __all__ = ( "AsyncmyConfig", + "AsyncmyConnection", "AsyncmyDriver", "AsyncmyPoolConfig", ) diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index c52bdd390..f291f35b3 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -3,37 +3,36 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from asyncmy import Connection + from sqlspec.base import AsyncDriverAdapterProtocol from sqlspec.mixins import SQLTranslatorMixin if TYPE_CHECKING: - from asyncmy import Connection from asyncmy.cursors import Cursor from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AsyncmyDriver",) +AsyncmyConnection = Connection + class AsyncmyDriver( - SQLTranslatorMixin["Connection"], - AsyncDriverAdapterProtocol["Connection"], + SQLTranslatorMixin["AsyncmyConnection"], + AsyncDriverAdapterProtocol["AsyncmyConnection"], ): """Asyncmy MySQL/MariaDB Driver Adapter.""" - connection: "Connection" + connection: "AsyncmyConnection" dialect: str = "mysql" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "AsyncmyConnection") -> None: self.connection = connection - @staticmethod - async def _cursor(connection: "Connection") -> "Cursor": - return await connection.cursor() - @staticmethod @asynccontextmanager - async def _with_cursor(connection: "Connection") -> AsyncGenerator["Cursor", None]: + async def _with_cursor(connection: "AsyncmyConnection") -> AsyncGenerator["Cursor", None]: cursor = connection.cursor() try: yield cursor @@ -48,7 +47,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -59,7 +58,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -69,7 +68,7 @@ async def select( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -97,7 +96,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -108,7 +107,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -118,7 +117,7 @@ async def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -145,7 +144,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -156,7 +155,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -166,7 +165,7 @@ async def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -194,7 +193,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -205,7 +204,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -215,7 +214,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -244,7 +243,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -255,7 +254,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -265,7 +264,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -295,7 +294,7 @@ async def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -317,7 +316,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -328,7 +327,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -338,7 +337,7 @@ async def insert_update_delete_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -367,7 +366,7 @@ async def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["AsyncmyConnection"] = None, **kwargs: Any, ) -> str: """Execute a script. diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 3710bcead..0acc9c5cd 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -190,9 +190,7 @@ async def create_connection(self) -> "AsyncpgConnection": raise ImproperConfigurationError(msg) from e @asynccontextmanager - async def provide_connection( - self, *args: "Any", **kwargs: "Any" - ) -> "AsyncGenerator[PoolConnectionProxy[Any], None]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType] + async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncpgConnection, None]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType] """Create a connection instance. Yields: diff --git a/sqlspec/adapters/duckdb/__init__.py b/sqlspec/adapters/duckdb/__init__.py index 39dfdae57..f1e613c19 100644 --- a/sqlspec/adapters/duckdb/__init__.py +++ b/sqlspec/adapters/duckdb/__init__.py @@ -1,7 +1,8 @@ from sqlspec.adapters.duckdb.config import DuckDBConfig -from sqlspec.adapters.duckdb.driver import DuckDBDriver +from sqlspec.adapters.duckdb.driver import DuckDBConnection, DuckDBDriver __all__ = ( "DuckDBConfig", + "DuckDBConnection", "DuckDBDriver", ) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 68270901b..ce6aaab5d 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -2,10 +2,9 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast -from duckdb import DuckDBPyConnection from typing_extensions import Literal, NotRequired, TypedDict -from sqlspec.adapters.duckdb.driver import DuckDBDriver +from sqlspec.adapters.duckdb.driver import DuckDBConnection, DuckDBDriver from sqlspec.base import NoPoolSyncConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType, dataclass_to_dict @@ -69,7 +68,7 @@ class SecretConfig(TypedDict): @dataclass -class DuckDBConfig(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]): +class DuckDBConfig(NoPoolSyncConfig["DuckDBConnection", "DuckDBDriver"]): """Configuration for DuckDB database connections. This class provides configuration options for DuckDB database connections, wrapping all parameters @@ -96,10 +95,10 @@ class DuckDBConfig(NoPoolSyncConfig["DuckDBPyConnection", "DuckDBDriver"]): """A dictionary of secrets to store in the connection for later retrieval.""" auto_update_extensions: "bool" = False """Whether to automatically update on connection creation""" - on_connection_create: "Optional[Callable[[DuckDBPyConnection], Optional[DuckDBPyConnection]]]" = None + on_connection_create: "Optional[Callable[[DuckDBConnection], Optional[DuckDBConnection]]]" = None """A callable to be called after the connection is created.""" - connection_type: "type[DuckDBPyConnection]" = field(init=False, default_factory=lambda: DuckDBPyConnection) - """The type of connection to create. Defaults to DuckDBPyConnection.""" + connection_type: "type[DuckDBConnection]" = field(init=False, default_factory=lambda: DuckDBConnection) + """The type of connection to create. Defaults to DuckDBConnection.""" driver_type: "type[DuckDBDriver]" = field(init=False, default_factory=lambda: DuckDBDriver) # type: ignore[type-abstract,unused-ignore] """The type of driver to use. Defaults to DuckDBDriver.""" pool_instance: "None" = field(init=False, default=None) @@ -139,7 +138,7 @@ def __post_init__(self) -> None: raise ImproperConfigurationError(msg) from e self.extensions.extend(config_exts) - def _configure_connection(self, connection: "DuckDBPyConnection") -> None: + def _configure_connection(self, connection: "DuckDBConnection") -> None: """Configure the connection. Args: @@ -148,7 +147,7 @@ def _configure_connection(self, connection: "DuckDBPyConnection") -> None: for key, value in cast("dict[str,Any]", self.config).items(): connection.execute(f"SET {key}='{value}'") - def _configure_extensions(self, connection: "DuckDBPyConnection") -> None: + def _configure_extensions(self, connection: "DuckDBConnection") -> None: """Configure extensions for the connection. Args: @@ -165,7 +164,7 @@ def _configure_extensions(self, connection: "DuckDBPyConnection") -> None: connection.execute("update extensions") @staticmethod - def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool: + def _secret_exists(connection: "DuckDBConnection", name: "str") -> bool: """Check if a secret exists in the connection. Args: @@ -179,7 +178,7 @@ def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool: return results is not None @classmethod - def _is_community_extension(cls, connection: "DuckDBPyConnection", name: "str") -> bool: + def _is_community_extension(cls, connection: "DuckDBConnection", name: "str") -> bool: """Check if an extension is a community extension. Args: @@ -195,7 +194,7 @@ def _is_community_extension(cls, connection: "DuckDBPyConnection", name: "str") return results is None @classmethod - def _extension_installed(cls, connection: "DuckDBPyConnection", name: "str") -> bool: + def _extension_installed(cls, connection: "DuckDBConnection", name: "str") -> bool: """Check if a extension exists in the connection. Args: @@ -211,7 +210,7 @@ def _extension_installed(cls, connection: "DuckDBPyConnection", name: "str") -> return results is not None @classmethod - def _extension_loaded(cls, connection: "DuckDBPyConnection", name: "str") -> bool: + def _extension_loaded(cls, connection: "DuckDBConnection", name: "str") -> bool: """Check if a extension is loaded in the connection. Args: @@ -229,7 +228,7 @@ def _extension_loaded(cls, connection: "DuckDBPyConnection", name: "str") -> boo @classmethod def _configure_secrets( cls, - connection: "DuckDBPyConnection", + connection: "DuckDBConnection", secrets: "Sequence[SecretConfig]", ) -> None: """Configure persistent secrets for the connection. @@ -258,7 +257,7 @@ def _configure_secrets( raise ImproperConfigurationError(msg) from e @classmethod - def _configure_extension(cls, connection: "DuckDBPyConnection", extension: "ExtensionConfig") -> None: + def _configure_extension(cls, connection: "DuckDBConnection", extension: "ExtensionConfig") -> None: """Configure a single extension for the connection. Args: @@ -320,6 +319,7 @@ def connection_config_dict(self) -> "dict[str, Any]": "auto_update_extensions", "driver_type", "connection_type", + "connection_instance", }, convert_nested=False, ) @@ -327,7 +327,7 @@ def connection_config_dict(self) -> "dict[str, Any]": config["database"] = ":memory:" return config - def create_connection(self) -> "DuckDBPyConnection": + def create_connection(self) -> "DuckDBConnection": """Create and return a new database connection with configured extensions. Returns: @@ -349,11 +349,10 @@ def create_connection(self) -> "DuckDBPyConnection": except Exception as e: msg = f"Could not configure the DuckDB connection. Error: {e!s}" raise ImproperConfigurationError(msg) from e - else: - return connection + return connection @contextmanager - def provide_connection(self, *args: Any, **kwargs: Any) -> "Generator[DuckDBPyConnection, None, None]": + def provide_connection(self, *args: Any, **kwargs: Any) -> "Generator[DuckDBConnection, None, None]": """Create and provide a database connection. Yields: diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 6f0238103..de409ac52 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,6 +1,8 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from duckdb import DuckDBPyConnection + from sqlspec.base import SyncDriverAdapterProtocol from sqlspec.mixins import SQLTranslatorMixin, SyncArrowBulkOperationsMixin from sqlspec.typing import ArrowTable, StatementParameterType @@ -8,36 +10,36 @@ if TYPE_CHECKING: from collections.abc import Generator, Sequence - from duckdb import DuckDBPyConnection - from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T -__all__ = ("DuckDBDriver",) +__all__ = ("DuckDBConnection", "DuckDBDriver") + +DuckDBConnection = DuckDBPyConnection class DuckDBDriver( - SyncArrowBulkOperationsMixin["DuckDBPyConnection"], - SQLTranslatorMixin["DuckDBPyConnection"], - SyncDriverAdapterProtocol["DuckDBPyConnection"], + SyncArrowBulkOperationsMixin["DuckDBConnection"], + SQLTranslatorMixin["DuckDBConnection"], + SyncDriverAdapterProtocol["DuckDBConnection"], ): """DuckDB Sync Driver Adapter.""" - connection: "DuckDBPyConnection" + connection: "DuckDBConnection" use_cursor: bool = True dialect: str = "duckdb" - def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> None: + def __init__(self, connection: "DuckDBConnection", use_cursor: bool = True) -> None: self.connection = connection self.use_cursor = use_cursor # --- Helper Methods --- # - def _cursor(self, connection: "DuckDBPyConnection") -> "DuckDBPyConnection": + def _cursor(self, connection: "DuckDBConnection") -> "DuckDBConnection": if self.use_cursor: return connection.cursor() return connection @contextmanager - def _with_cursor(self, connection: "DuckDBPyConnection") -> "Generator[DuckDBPyConnection, None, None]": + def _with_cursor(self, connection: "DuckDBConnection") -> "Generator[DuckDBConnection, None, None]": if self.use_cursor: cursor = self._cursor(connection) try: @@ -55,7 +57,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -66,7 +68,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -76,7 +78,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -101,7 +103,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -112,7 +114,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -122,7 +124,7 @@ def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: Optional["DuckDBConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -146,7 +148,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -157,7 +159,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -167,7 +169,7 @@ def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: Optional["DuckDBConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -191,7 +193,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -202,7 +204,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -212,7 +214,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -233,7 +235,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -244,7 +246,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -254,7 +256,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -275,7 +277,7 @@ def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: Optional["DuckDBConnection"] = None, **kwargs: Any, ) -> int: connection = self._connection(connection) @@ -291,7 +293,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -302,7 +304,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -312,7 +314,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -334,7 +336,7 @@ def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: Optional["DuckDBConnection"] = None, **kwargs: Any, ) -> str: connection = self._connection(connection) @@ -351,7 +353,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[DuckDBPyConnection]" = None, + connection: "Optional[DuckDBConnection]" = None, **kwargs: Any, ) -> "ArrowTable": connection = self._connection(connection) diff --git a/sqlspec/adapters/oracledb/__init__.py b/sqlspec/adapters/oracledb/__init__.py index 4a6af3d74..224a80ed9 100644 --- a/sqlspec/adapters/oracledb/__init__.py +++ b/sqlspec/adapters/oracledb/__init__.py @@ -4,13 +4,20 @@ OracleSyncConfig, OracleSyncPoolConfig, ) -from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver +from sqlspec.adapters.oracledb.driver import ( + OracleAsyncConnection, + OracleAsyncDriver, + OracleSyncConnection, + OracleSyncDriver, +) __all__ = ( "OracleAsyncConfig", + "OracleAsyncConnection", "OracleAsyncDriver", "OracleAsyncPoolConfig", "OracleSyncConfig", + "OracleSyncConnection", "OracleSyncDriver", "OracleSyncPoolConfig", ) diff --git a/sqlspec/adapters/oracledb/config/_asyncio.py b/sqlspec/adapters/oracledb/config/_asyncio.py index bb9c765fb..6e088d636 100644 --- a/sqlspec/adapters/oracledb/config/_asyncio.py +++ b/sqlspec/adapters/oracledb/config/_asyncio.py @@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Any, Optional, cast from oracledb import create_pool_async as oracledb_create_pool # pyright: ignore[reportUnknownVariableType] -from oracledb.connection import AsyncConnection from sqlspec.adapters.oracledb.config._common import OracleGenericPoolConfig -from sqlspec.adapters.oracledb.driver import OracleAsyncDriver +from sqlspec.adapters.oracledb.driver import OracleAsyncConnection, OracleAsyncDriver from sqlspec.base import AsyncDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import dataclass_to_dict @@ -24,12 +23,12 @@ @dataclass -class OracleAsyncPoolConfig(OracleGenericPoolConfig["AsyncConnection", "AsyncConnectionPool"]): +class OracleAsyncPoolConfig(OracleGenericPoolConfig["OracleAsyncConnection", "AsyncConnectionPool"]): """Async Oracle Pool Config""" @dataclass -class OracleAsyncConfig(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionPool", "OracleAsyncDriver"]): +class OracleAsyncConfig(AsyncDatabaseConfig["OracleAsyncConnection", "AsyncConnectionPool", "OracleAsyncDriver"]): """Oracle Async database Configuration. This class provides the base configuration for Oracle database connections, extending @@ -49,7 +48,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig["AsyncConnection", "AsyncConnectionP If set, the plugin will use the provided pool rather than instantiate one. """ - connection_type: "type[AsyncConnection]" = field(init=False, default_factory=lambda: AsyncConnection) + connection_type: "type[OracleAsyncConnection]" = field(init=False, default_factory=lambda: OracleAsyncConnection) """Connection class to use. Defaults to :class:`AsyncConnection`. @@ -111,7 +110,7 @@ def pool_config_dict(self) -> "dict[str, Any]": msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) - async def create_connection(self) -> "AsyncConnection": + async def create_connection(self) -> "OracleAsyncConnection": """Create and return a new oracledb async connection from the pool. Returns: @@ -122,7 +121,7 @@ async def create_connection(self) -> "AsyncConnection": """ try: pool = await self.provide_pool() - return cast("AsyncConnection", await pool.acquire()) # type: ignore[no-any-return,unused-ignore] + return cast("OracleAsyncConnection", await pool.acquire()) # type: ignore[no-any-return,unused-ignore] except Exception as e: msg = f"Could not configure the Oracle async connection. Error: {e!s}" raise ImproperConfigurationError(msg) from e @@ -160,7 +159,7 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[AsyncConnect return self.create_pool() @asynccontextmanager - async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncConnection, None]": + async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[OracleAsyncConnection, None]": """Create a connection instance. Yields: diff --git a/sqlspec/adapters/oracledb/config/_sync.py b/sqlspec/adapters/oracledb/config/_sync.py index 35a87e167..4b300db28 100644 --- a/sqlspec/adapters/oracledb/config/_sync.py +++ b/sqlspec/adapters/oracledb/config/_sync.py @@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Any, Optional from oracledb import create_pool as oracledb_create_pool # pyright: ignore[reportUnknownVariableType] -from oracledb.connection import Connection from sqlspec.adapters.oracledb.config._common import OracleGenericPoolConfig -from sqlspec.adapters.oracledb.driver import OracleSyncDriver +from sqlspec.adapters.oracledb.driver import OracleSyncConnection, OracleSyncDriver from sqlspec.base import SyncDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import dataclass_to_dict @@ -24,12 +23,12 @@ @dataclass -class OracleSyncPoolConfig(OracleGenericPoolConfig["Connection", "ConnectionPool"]): +class OracleSyncPoolConfig(OracleGenericPoolConfig["OracleSyncConnection", "ConnectionPool"]): """Sync Oracle Pool Config""" @dataclass -class OracleSyncConfig(SyncDatabaseConfig["Connection", "ConnectionPool", "OracleSyncDriver"]): +class OracleSyncConfig(SyncDatabaseConfig["OracleSyncConnection", "ConnectionPool", "OracleSyncDriver"]): """Oracle Sync database Configuration. This class provides the base configuration for Oracle database connections, extending @@ -49,7 +48,7 @@ class OracleSyncConfig(SyncDatabaseConfig["Connection", "ConnectionPool", "Oracl If set, the plugin will use the provided pool rather than instantiate one. """ - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore + connection_type: "type[OracleSyncConnection]" = field(init=False, default_factory=lambda: OracleSyncConnection) # pyright: ignore """Connection class to use. Defaults to :class:`Connection`. @@ -111,7 +110,7 @@ def pool_config_dict(self) -> "dict[str, Any]": msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) - def create_connection(self) -> "Connection": + def create_connection(self) -> "OracleSyncConnection": """Create and return a new oracledb connection from the pool. Returns: @@ -160,7 +159,7 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "ConnectionPool": return self.create_pool() @contextmanager - def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]": + def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[OracleSyncConnection, None, None]": """Create a connection instance. Yields: diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 5d43d8757..52c2e2a02 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1,6 +1,8 @@ from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor + from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol from sqlspec.mixins import AsyncArrowBulkOperationsMixin, SQLTranslatorMixin, SyncArrowBulkOperationsMixin from sqlspec.typing import ArrowTable, StatementParameterType, T @@ -8,30 +10,30 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Sequence - from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor - - # Conditionally import ArrowTable for type checking from sqlspec.typing import ModelDTOT -__all__ = ("OracleAsyncDriver", "OracleSyncDriver") +__all__ = ("OracleAsyncConnection", "OracleAsyncDriver", "OracleSyncConnection", "OracleSyncDriver") + +OracleSyncConnection = Connection +OracleAsyncConnection = AsyncConnection class OracleSyncDriver( - SyncArrowBulkOperationsMixin["Connection"], - SQLTranslatorMixin["Connection"], - SyncDriverAdapterProtocol["Connection"], + SyncArrowBulkOperationsMixin["OracleSyncConnection"], + SQLTranslatorMixin["OracleSyncConnection"], + SyncDriverAdapterProtocol["OracleSyncConnection"], ): """Oracle Sync Driver Adapter.""" - connection: "Connection" + connection: "OracleSyncConnection" dialect: str = "oracle" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "OracleSyncConnection") -> None: self.connection = connection @staticmethod @contextmanager - def _with_cursor(connection: "Connection") -> "Generator[Cursor, None, None]": + def _with_cursor(connection: "OracleSyncConnection") -> "Generator[Cursor, None, None]": cursor = connection.cursor() try: yield cursor @@ -46,7 +48,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -57,7 +59,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -67,7 +69,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -105,7 +107,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -116,7 +118,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -126,7 +128,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -165,7 +167,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -176,7 +178,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -186,7 +188,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -220,7 +222,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -231,7 +233,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -241,7 +243,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -269,7 +271,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -280,7 +282,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -290,7 +292,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -319,7 +321,7 @@ def insert_update_delete( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -341,7 +343,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -352,7 +354,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -362,7 +364,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -395,7 +397,7 @@ def execute_script( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, ) -> str: """Execute a script. @@ -416,7 +418,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -432,21 +434,21 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] class OracleAsyncDriver( - AsyncArrowBulkOperationsMixin["AsyncConnection"], - SQLTranslatorMixin["AsyncConnection"], - AsyncDriverAdapterProtocol["AsyncConnection"], + AsyncArrowBulkOperationsMixin["OracleAsyncConnection"], + SQLTranslatorMixin["OracleAsyncConnection"], + AsyncDriverAdapterProtocol["OracleAsyncConnection"], ): """Oracle Async Driver Adapter.""" - connection: "AsyncConnection" + connection: "OracleAsyncConnection" dialect: str = "oracle" - def __init__(self, connection: "AsyncConnection") -> None: + def __init__(self, connection: "OracleAsyncConnection") -> None: self.connection = connection @staticmethod @asynccontextmanager - async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[AsyncCursor, None]": + async def _with_cursor(connection: "OracleAsyncConnection") -> "AsyncGenerator[AsyncCursor, None]": cursor = connection.cursor() try: yield cursor @@ -461,7 +463,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -472,7 +474,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -482,7 +484,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -514,7 +516,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -525,7 +527,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -535,7 +537,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -566,7 +568,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -577,7 +579,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -587,7 +589,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -621,7 +623,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -632,7 +634,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -642,7 +644,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -670,7 +672,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -681,7 +683,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -691,7 +693,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -720,7 +722,7 @@ async def insert_update_delete( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -742,7 +744,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -753,7 +755,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -763,7 +765,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -796,7 +798,7 @@ async def execute_script( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> str: """Execute a script. @@ -817,7 +819,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query asynchronously and return results as an Apache Arrow Table. diff --git a/sqlspec/adapters/psqlpy/__init__.py b/sqlspec/adapters/psqlpy/__init__.py index e69de29bb..a48b7109e 100644 --- a/sqlspec/adapters/psqlpy/__init__.py +++ b/sqlspec/adapters/psqlpy/__init__.py @@ -0,0 +1,9 @@ +from sqlspec.adapters.psqlpy.config import PsqlpyConfig, PsqlpyPoolConfig +from sqlspec.adapters.psqlpy.driver import PsqlpyConnection, PsqlpyDriver + +__all__ = ( + "PsqlpyConfig", + "PsqlpyConnection", + "PsqlpyDriver", + "PsqlpyPoolConfig", +) diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 193089714..dac82c5a1 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -6,7 +6,7 @@ from psqlpy import Connection, ConnectionPool -from sqlspec.adapters.psqlpy.driver import PsqlpyDriver +from sqlspec.adapters.psqlpy.driver import PsqlpyConnection, PsqlpyDriver from sqlspec.base import AsyncDatabaseConfig, GenericPoolConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType, dataclass_to_dict @@ -94,7 +94,7 @@ class PsqlpyPoolConfig(GenericPoolConfig): @dataclass -class PsqlpyConfig(AsyncDatabaseConfig[Connection, ConnectionPool, PsqlpyDriver]): +class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyDriver]): """Configuration for psqlpy database connections, managing a connection pool. This configuration class wraps `PsqlpyPoolConfig` and manages the lifecycle @@ -105,7 +105,7 @@ class PsqlpyConfig(AsyncDatabaseConfig[Connection, ConnectionPool, PsqlpyDriver] """Psqlpy Pool configuration""" driver_type: type[PsqlpyDriver] = field(default=PsqlpyDriver, init=False, hash=False) """Type of the driver object""" - connection_type: type[Connection] = field(default=Connection, init=False, hash=False) + connection_type: type[PsqlpyConnection] = field(default=PsqlpyConnection, init=False, hash=False) """Type of the connection object""" pool_instance: Optional[ConnectionPool] = field(default=None, hash=False) """The connection pool instance. If set, this will be used instead of creating a new pool.""" @@ -204,7 +204,7 @@ async def _create() -> "ConnectionPool": return _create() - def create_connection(self) -> "Awaitable[Connection]": + def create_connection(self) -> "Awaitable[PsqlpyConnection]": """Create and return a new, standalone psqlpy connection using the configured parameters. Returns: @@ -222,7 +222,7 @@ async def _create() -> "Connection": return _create() @asynccontextmanager - async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[Connection, None]": + async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[PsqlpyConnection, None]": """Acquire a connection from the pool. Yields: diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 9fef18717..c38dcad1c 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -5,6 +5,7 @@ import re from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from psqlpy import Connection, QueryResult from psqlpy.exceptions import RustPSQLDriverPyBaseError from sqlspec.base import AsyncDriverAdapterProtocol @@ -15,13 +16,14 @@ if TYPE_CHECKING: from collections.abc import Sequence - from psqlpy import Connection, QueryResult + from psqlpy import QueryResult from sqlspec.typing import ModelDTOT, StatementParameterType, T -__all__ = ("PsqlpyDriver",) +__all__ = ("PsqlpyConnection", "PsqlpyDriver") +PsqlpyConnection = Connection # Regex to find '?' placeholders, skipping those inside quotes or SQL comments QMARK_REGEX = re.compile( r"""(?P"[^"]*") | # Double-quoted strings @@ -35,15 +37,15 @@ class PsqlpyDriver( - SQLTranslatorMixin["Connection"], - AsyncDriverAdapterProtocol["Connection"], + SQLTranslatorMixin["PsqlpyConnection"], + AsyncDriverAdapterProtocol["PsqlpyConnection"], ): """Psqlpy Postgres Driver Adapter.""" - connection: "Connection" + connection: "PsqlpyConnection" dialect: str = "postgres" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "PsqlpyConnection") -> None: self.connection = connection def _process_sql_params( @@ -179,7 +181,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -190,7 +192,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -200,7 +202,7 @@ async def select( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -221,7 +223,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -232,7 +234,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -242,7 +244,7 @@ async def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -264,7 +266,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -275,7 +277,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -285,7 +287,7 @@ async def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -311,7 +313,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -322,7 +324,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -332,7 +334,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -353,7 +355,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -364,7 +366,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -374,7 +376,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -398,7 +400,7 @@ async def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, **kwargs: Any, ) -> int: connection = self._connection(connection) @@ -417,7 +419,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -428,7 +430,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -438,7 +440,7 @@ async def insert_update_delete_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -463,7 +465,7 @@ async def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["PsqlpyConnection"] = None, **kwargs: Any, ) -> str: connection = self._connection(connection) @@ -473,7 +475,7 @@ async def execute_script( await connection.execute(sql, parameters=parameters) return sql - def _connection(self, connection: Optional["Connection"] = None) -> "Connection": + def _connection(self, connection: Optional["PsqlpyConnection"] = None) -> "PsqlpyConnection": """Get the connection to use. Args: diff --git a/sqlspec/adapters/psycopg/__init__.py b/sqlspec/adapters/psycopg/__init__.py index 1105f5b43..fb7ccf8ef 100644 --- a/sqlspec/adapters/psycopg/__init__.py +++ b/sqlspec/adapters/psycopg/__init__.py @@ -4,13 +4,20 @@ PsycopgSyncConfig, PsycopgSyncPoolConfig, ) -from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver +from sqlspec.adapters.psycopg.driver import ( + PsycopgAsyncConnection, + PsycopgAsyncDriver, + PsycopgSyncConnection, + PsycopgSyncDriver, +) __all__ = ( "PsycopgAsyncConfig", + "PsycopgAsyncConnection", "PsycopgAsyncDriver", "PsycopgAsyncPoolConfig", "PsycopgSyncConfig", + "PsycopgSyncConnection", "PsycopgSyncDriver", "PsycopgSyncPoolConfig", ) diff --git a/sqlspec/adapters/psycopg/config/__init__.py b/sqlspec/adapters/psycopg/config/__init__.py index a3ab74788..9c9277bcb 100644 --- a/sqlspec/adapters/psycopg/config/__init__.py +++ b/sqlspec/adapters/psycopg/config/__init__.py @@ -1,9 +1,19 @@ from sqlspec.adapters.psycopg.config._async import PsycopgAsyncConfig, PsycopgAsyncPoolConfig from sqlspec.adapters.psycopg.config._sync import PsycopgSyncConfig, PsycopgSyncPoolConfig +from sqlspec.adapters.psycopg.driver import ( + PsycopgAsyncConnection, + PsycopgAsyncDriver, + PsycopgSyncConnection, + PsycopgSyncDriver, +) __all__ = ( "PsycopgAsyncConfig", + "PsycopgAsyncConnection", + "PsycopgAsyncDriver", "PsycopgAsyncPoolConfig", "PsycopgSyncConfig", + "PsycopgSyncConnection", + "PsycopgSyncDriver", "PsycopgSyncPoolConfig", ) diff --git a/sqlspec/adapters/psycopg/config/_async.py b/sqlspec/adapters/psycopg/config/_async.py index a0f0569de..a301c7d6a 100644 --- a/sqlspec/adapters/psycopg/config/_async.py +++ b/sqlspec/adapters/psycopg/config/_async.py @@ -2,11 +2,10 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional -from psycopg import AsyncConnection from psycopg_pool import AsyncConnectionPool from sqlspec.adapters.psycopg.config._common import PsycopgGenericPoolConfig -from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver +from sqlspec.adapters.psycopg.driver import PsycopgAsyncConnection, PsycopgAsyncDriver from sqlspec.base import AsyncDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import dataclass_to_dict @@ -22,12 +21,12 @@ @dataclass -class PsycopgAsyncPoolConfig(PsycopgGenericPoolConfig[AsyncConnection, AsyncConnectionPool]): +class PsycopgAsyncPoolConfig(PsycopgGenericPoolConfig[PsycopgAsyncConnection, AsyncConnectionPool]): """Async Psycopg Pool Config""" @dataclass -class PsycopgAsyncConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]): +class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]): """Async Psycopg database Configuration. This class provides the base configuration for Psycopg database connections, extending @@ -41,7 +40,7 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnectionPoo """Psycopg Pool configuration""" pool_instance: "Optional[AsyncConnectionPool]" = None """Optional pool to use""" - connection_type: "type[AsyncConnection]" = field(init=False, default_factory=lambda: AsyncConnection) # type: ignore[assignment] + connection_type: "type[PsycopgAsyncConnection]" = field(init=False, default_factory=lambda: PsycopgAsyncConnection) # type: ignore[assignment] """Type of the connection object""" driver_type: "type[PsycopgAsyncDriver]" = field(init=False, default_factory=lambda: PsycopgAsyncDriver) # type: ignore[type-abstract,unused-ignore] """Type of the driver object""" @@ -93,7 +92,7 @@ def pool_config_dict(self) -> "dict[str, Any]": msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) - async def create_connection(self) -> "AsyncConnection": + async def create_connection(self) -> "PsycopgAsyncConnection": """Create and return a new psycopg async connection from the pool. Returns: @@ -143,7 +142,7 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[AsyncConnect return self.create_pool() @asynccontextmanager - async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncConnection, None]": + async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[PsycopgAsyncConnection, None]": """Create and provide a database connection. Yields: diff --git a/sqlspec/adapters/psycopg/config/_sync.py b/sqlspec/adapters/psycopg/config/_sync.py index 75ea340c8..5eddaf359 100644 --- a/sqlspec/adapters/psycopg/config/_sync.py +++ b/sqlspec/adapters/psycopg/config/_sync.py @@ -2,11 +2,10 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional -from psycopg import Connection from psycopg_pool import ConnectionPool from sqlspec.adapters.psycopg.config._common import PsycopgGenericPoolConfig -from sqlspec.adapters.psycopg.driver import PsycopgSyncDriver +from sqlspec.adapters.psycopg.driver import PsycopgSyncConnection, PsycopgSyncDriver from sqlspec.base import SyncDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import dataclass_to_dict @@ -22,12 +21,12 @@ @dataclass -class PsycopgSyncPoolConfig(PsycopgGenericPoolConfig[Connection, ConnectionPool]): +class PsycopgSyncPoolConfig(PsycopgGenericPoolConfig[PsycopgSyncConnection, ConnectionPool]): """Sync Psycopg Pool Config""" @dataclass -class PsycopgSyncConfig(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSyncDriver]): +class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool, PsycopgSyncDriver]): """Sync Psycopg database Configuration. This class provides the base configuration for Psycopg database connections, extending the generic database configuration with Psycopg-specific settings.([1](https://www.psycopg.org/psycopg3/docs/api/connections.html)) @@ -40,7 +39,7 @@ class PsycopgSyncConfig(SyncDatabaseConfig[Connection, ConnectionPool, PsycopgSy """Psycopg Pool configuration""" pool_instance: "Optional[ConnectionPool]" = None """Optional pool to use""" - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # type: ignore[assignment] + connection_type: "type[PsycopgSyncConnection]" = field(init=False, default_factory=lambda: PsycopgSyncConnection) # type: ignore[assignment] """Type of the connection object""" driver_type: "type[PsycopgSyncDriver]" = field(init=False, default_factory=lambda: PsycopgSyncDriver) # type: ignore[type-abstract,unused-ignore] """Type of the driver object""" @@ -92,7 +91,7 @@ def pool_config_dict(self) -> "dict[str, Any]": msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) - def create_connection(self) -> "Connection": + def create_connection(self) -> "PsycopgSyncConnection": """Create and return a new psycopg connection from the pool. Returns: @@ -142,11 +141,11 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "ConnectionPool": return self.create_pool() @contextmanager - def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]": + def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[PsycopgSyncConnection, None, None]": """Create and provide a database connection. Yields: - Connection: A database connection from the pool. + PsycopgSyncConnection: A database connection from the pool. """ pool = self.provide_pool(*args, **kwargs) with pool, pool.connection() as connection: diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 4d94c3bbe..ab7760495 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from psycopg import AsyncConnection, Connection from psycopg.rows import dict_row from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol @@ -12,13 +13,14 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Sequence - from psycopg import AsyncConnection, Connection - from sqlspec.typing import ModelDTOT, StatementParameterType, T logger = logging.getLogger("sqlspec") -__all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver") +__all__ = ("PsycopgAsyncConnection", "PsycopgAsyncDriver", "PsycopgSyncConnection", "PsycopgSyncDriver") + +PsycopgSyncConnection = Connection +PsycopgAsyncConnection = AsyncConnection class PsycopgDriverBase: @@ -79,15 +81,15 @@ def _process_sql_params( class PsycopgSyncDriver( PsycopgDriverBase, - SQLTranslatorMixin["Connection"], - SyncDriverAdapterProtocol["Connection"], + SQLTranslatorMixin["PsycopgSyncConnection"], + SyncDriverAdapterProtocol["PsycopgSyncConnection"], ): """Psycopg Sync Driver Adapter.""" - connection: "Connection" + connection: "PsycopgSyncConnection" dialect: str = "postgres" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "PsycopgSyncConnection") -> None: self.connection = connection def _process_sql_params( @@ -97,7 +99,6 @@ def _process_sql_params( /, **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL and parameters, converting :name -> %(name)s if needed.""" stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) processed_sql, processed_params = stmt.process() @@ -144,7 +145,7 @@ def _process_sql_params( @staticmethod @contextmanager - def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": + def _with_cursor(connection: "PsycopgSyncConnection") -> "Generator[Any, None, None]": cursor = connection.cursor(row_factory=dict_row) try: yield cursor @@ -159,7 +160,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -170,7 +171,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -181,7 +182,7 @@ def select( /, *, schema_type: "Optional[type[ModelDTOT]]" = None, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -208,7 +209,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -219,7 +220,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -229,7 +230,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -255,7 +256,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -266,7 +267,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -276,7 +277,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -303,7 +304,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -314,7 +315,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -324,7 +325,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -352,7 +353,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -363,7 +364,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -373,7 +374,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -402,7 +403,7 @@ def insert_update_delete( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, **kwargs: Any, ) -> int: """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. @@ -423,7 +424,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -434,7 +435,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -444,7 +445,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -472,7 +473,7 @@ def execute_script( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[PsycopgSyncConnection]" = None, **kwargs: Any, ) -> str: """Execute a script. @@ -489,20 +490,20 @@ def execute_script( class PsycopgAsyncDriver( PsycopgDriverBase, - SQLTranslatorMixin["AsyncConnection"], - AsyncDriverAdapterProtocol["AsyncConnection"], + SQLTranslatorMixin["PsycopgAsyncConnection"], + AsyncDriverAdapterProtocol["PsycopgAsyncConnection"], ): """Psycopg Async Driver Adapter.""" - connection: "AsyncConnection" + connection: "PsycopgAsyncConnection" dialect: str = "postgres" - def __init__(self, connection: "AsyncConnection") -> None: + def __init__(self, connection: "PsycopgAsyncConnection") -> None: self.connection = connection @staticmethod @asynccontextmanager - async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, None]": + async def _with_cursor(connection: "PsycopgAsyncConnection") -> "AsyncGenerator[Any, None]": cursor = connection.cursor(row_factory=dict_row) try: yield cursor @@ -517,7 +518,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -528,7 +529,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -538,7 +539,7 @@ async def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -567,7 +568,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -578,7 +579,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -588,7 +589,7 @@ async def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -615,7 +616,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -626,7 +627,7 @@ async def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -637,7 +638,7 @@ async def select_one_or_none( /, *, schema_type: "Optional[type[ModelDTOT]]" = None, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -664,7 +665,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -675,7 +676,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -685,7 +686,7 @@ async def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -713,7 +714,7 @@ async def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -743,7 +744,7 @@ async def insert_update_delete( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, **kwargs: Any, ) -> int: """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. @@ -769,7 +770,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -780,7 +781,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -790,7 +791,7 @@ async def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -819,7 +820,7 @@ async def execute_script( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[AsyncConnection]" = None, + connection: "Optional[PsycopgAsyncConnection]" = None, **kwargs: Any, ) -> str: """Execute a script. diff --git a/sqlspec/adapters/sqlite/__init__.py b/sqlspec/adapters/sqlite/__init__.py index a97d890a5..ad7d4658e 100644 --- a/sqlspec/adapters/sqlite/__init__.py +++ b/sqlspec/adapters/sqlite/__init__.py @@ -1,7 +1,8 @@ from sqlspec.adapters.sqlite.config import SqliteConfig -from sqlspec.adapters.sqlite.driver import SqliteDriver +from sqlspec.adapters.sqlite.driver import SqliteConnection, SqliteDriver __all__ = ( "SqliteConfig", + "SqliteConnection", "SqliteDriver", ) diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 087306a65..71c0a043b 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -1,9 +1,9 @@ +import sqlite3 from contextlib import contextmanager from dataclasses import dataclass, field -from sqlite3 import Connection from typing import TYPE_CHECKING, Any, Literal, Optional, Union -from sqlspec.adapters.sqlite.driver import SqliteDriver +from sqlspec.adapters.sqlite.driver import SqliteConnection, SqliteDriver from sqlspec.base import NoPoolSyncConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.typing import Empty, EmptyType, dataclass_to_dict @@ -16,7 +16,7 @@ @dataclass -class SqliteConfig(NoPoolSyncConfig["Connection", "SqliteDriver"]): +class SqliteConfig(NoPoolSyncConfig["SqliteConnection", "SqliteDriver"]): """Configuration for SQLite database connections. This class provides configuration options for SQLite database connections, wrapping all parameters @@ -40,7 +40,7 @@ class SqliteConfig(NoPoolSyncConfig["Connection", "SqliteDriver"]): check_same_thread: "Union[bool, EmptyType]" = Empty """If True (default), ProgrammingError is raised if the database connection is used by a thread other than the one that created it. If False, the connection may be shared across multiple threads.""" - factory: "Union[type[Connection], EmptyType]" = Empty + factory: "Union[type[SqliteConnection], EmptyType]" = Empty """A custom Connection class factory. If given, must be a callable that returns a Connection instance.""" cached_statements: "Union[int, EmptyType]" = Empty @@ -50,7 +50,7 @@ class SqliteConfig(NoPoolSyncConfig["Connection", "SqliteDriver"]): """If set to True, database is interpreted as a URI with supported options.""" driver_type: "type[SqliteDriver]" = field(init=False, default_factory=lambda: SqliteDriver) """Type of the driver object""" - connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) + connection_type: "type[SqliteConnection]" = field(init=False, default_factory=lambda: SqliteConnection) """Type of the connection object""" @property @@ -61,10 +61,13 @@ def connection_config_dict(self) -> "dict[str, Any]": A string keyed dict of config kwargs for the sqlite3.connect() function. """ return dataclass_to_dict( - self, exclude_empty=True, convert_nested=False, exclude={"pool_instance", "driver_type", "connection_type"} + self, + exclude_empty=True, + convert_nested=False, + exclude={"pool_instance", "driver_type", "connection_type"}, ) - def create_connection(self) -> "Connection": + def create_connection(self) -> "SqliteConnection": """Create and return a new database connection. Returns: @@ -73,8 +76,6 @@ def create_connection(self) -> "Connection": Raises: ImproperConfigurationError: If the connection could not be established. """ - import sqlite3 - try: return sqlite3.connect(**self.connection_config_dict) # type: ignore[no-any-return,unused-ignore] except Exception as e: @@ -82,7 +83,7 @@ def create_connection(self) -> "Connection": raise ImproperConfigurationError(msg) from e @contextmanager - def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]": + def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[SqliteConnection, None, None]": """Create and provide a database connection. Yields: @@ -100,7 +101,7 @@ def provide_session(self, *args: Any, **kwargs: Any) -> "Generator[SqliteDriver, """Create and provide a database connection. Yields: - A DuckDB driver instance. + A SQLite driver instance. """ diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index f78f00239..5f353ab24 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -1,5 +1,6 @@ +import sqlite3 from contextlib import contextmanager -from sqlite3 import Connection, Cursor +from sqlite3 import Cursor from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import SyncDriverAdapterProtocol @@ -10,27 +11,29 @@ from sqlspec.typing import ModelDTOT, StatementParameterType, T -__all__ = ("SqliteDriver",) +__all__ = ("SqliteConnection", "SqliteDriver") + +SqliteConnection = sqlite3.Connection class SqliteDriver( - SQLTranslatorMixin["Connection"], - SyncDriverAdapterProtocol["Connection"], + SQLTranslatorMixin["SqliteConnection"], + SyncDriverAdapterProtocol["SqliteConnection"], ): """SQLite Sync Driver Adapter.""" - connection: "Connection" + connection: "SqliteConnection" dialect: str = "sqlite" - def __init__(self, connection: "Connection") -> None: + def __init__(self, connection: "SqliteConnection") -> None: self.connection = connection @staticmethod - def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> Cursor: + def _cursor(connection: "SqliteConnection", *args: Any, **kwargs: Any) -> Cursor: return connection.cursor(*args, **kwargs) # type: ignore[no-any-return] @contextmanager - def _with_cursor(self, connection: "Connection") -> "Generator[Cursor, None, None]": + def _with_cursor(self, connection: "SqliteConnection") -> "Generator[Cursor, None, None]": cursor = self._cursor(connection) try: yield cursor @@ -45,7 +48,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Sequence[dict[str, Any]]": ... @@ -56,7 +59,7 @@ def select( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Sequence[ModelDTOT]": ... @@ -66,7 +69,7 @@ def select( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": @@ -97,7 +100,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -108,7 +111,7 @@ def select_one( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -118,7 +121,7 @@ def select_one( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": @@ -148,7 +151,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[dict[str, Any]]": ... @@ -159,7 +162,7 @@ def select_one_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "Optional[ModelDTOT]": ... @@ -169,7 +172,7 @@ def select_one_or_none( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": @@ -200,7 +203,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Any": ... @@ -211,7 +214,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "T": ... @@ -221,7 +224,7 @@ def select_value( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Union[T, Any]": @@ -250,7 +253,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "Optional[Any]": ... @@ -261,7 +264,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[T]", **kwargs: Any, ) -> "Optional[T]": ... @@ -271,7 +274,7 @@ def select_value_or_none( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": @@ -300,7 +303,7 @@ def insert_update_delete( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -325,7 +328,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: None = None, **kwargs: Any, ) -> "dict[str, Any]": ... @@ -336,7 +339,7 @@ def insert_update_delete_returning( parameters: "Optional[StatementParameterType]" = None, /, *, - connection: "Optional[Connection]" = None, + connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, ) -> "ModelDTOT": ... @@ -346,7 +349,7 @@ def insert_update_delete_returning( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": @@ -390,7 +393,7 @@ def execute_script( parameters: Optional["StatementParameterType"] = None, /, *, - connection: Optional["Connection"] = None, + connection: Optional["SqliteConnection"] = None, **kwargs: Any, ) -> str: """Execute a script. diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 11e6820f2..53e72a0b9 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -84,6 +84,9 @@ def on_app_init(self, app_config: "AppConfig") -> "AppConfig": ) for c in self._plugin_configs: c.annotation = self.add_config(c.config) + app_config.signature_types.append(c.annotation) + app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr] + app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr] app_config.before_send.append(c.before_send_handler) app_config.lifespan.append(c.lifespan_handler) # pyright: ignore[reportUnknownMemberType] app_config.dependencies.update( diff --git a/tests/unit/test_adapters/test_aiosqlite/test_config.py b/tests/unit/test_adapters/test_aiosqlite/test_config.py index bdae47871..d642780d0 100644 --- a/tests/unit/test_adapters/test_aiosqlite/test_config.py +++ b/tests/unit/test_adapters/test_aiosqlite/test_config.py @@ -33,7 +33,6 @@ def test_minimal_config() -> None: assert config.detect_types is Empty assert config.isolation_level is Empty assert config.check_same_thread is Empty - assert config.factory is Empty assert config.cached_statements is Empty assert config.uri is Empty @@ -46,7 +45,6 @@ def test_full_config() -> None: detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, isolation_level="IMMEDIATE", check_same_thread=False, - factory=sqlite3.Connection, cached_statements=256, uri=True, ) @@ -56,7 +54,6 @@ def test_full_config() -> None: assert config.detect_types == sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES assert config.isolation_level == "IMMEDIATE" assert config.check_same_thread is False - assert config.factory == sqlite3.Connection assert config.cached_statements == 256 assert config.uri is True