Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sqlspec/adapters/adbc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlspec.adapters.adbc.config import AdbcConfig
from sqlspec.adapters.adbc.driver import AdbcDriver
from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver

__all__ = (
"AdbcConfig",
"AdbcConnection",
"AdbcDriver",
)
20 changes: 7 additions & 13 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

from adbc_driver_manager.dbapi import Connection

from sqlspec.adapters.adbc.driver import AdbcDriver
from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver
from sqlspec.base import NoPoolSyncConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.typing import Empty, EmptyType
Expand All @@ -18,7 +16,7 @@


@dataclass
class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
class AdbcConfig(NoPoolSyncConfig["AdbcConnection", "AdbcDriver"]):
"""Configuration for ADBC connections.

This class provides configuration options for ADBC database connections using the
Expand All @@ -33,20 +31,16 @@ class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
"""Additional database-specific connection parameters"""
conn_kwargs: "Optional[dict[str, Any]]" = None
"""Additional database-specific connection parameters"""
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection)
connection_type: "type[AdbcConnection]" = field(init=False, default_factory=lambda: AdbcConnection)
"""Type of the connection object"""
driver_type: "type[AdbcDriver]" = field(init=False, default_factory=lambda: AdbcDriver) # type: ignore[type-abstract,unused-ignore]
"""Type of the driver object"""
pool_instance: None = field(init=False, default=None)
pool_instance: None = field(init=False, default=None, hash=False)
"""No connection pool is used for ADBC connections"""
_is_in_memory: bool = field(init=False, default=False)
"""Flag indicating if the connection is for an in-memory database"""

def _set_adbc(self) -> str: # noqa: PLR0912
"""Identify the driver type based on the URI (if provided) or preset driver name.

Also sets the `_is_in_memory` flag for specific in-memory URIs.

Raises:
ImproperConfigurationError: If the driver name is not recognized or supported.

Expand Down Expand Up @@ -143,7 +137,7 @@ def connection_config_dict(self) -> "dict[str, Any]":
config["conn_kwargs"] = conn_kwargs
return config

def _get_connect_func(self) -> "Callable[..., Connection]":
def _get_connect_func(self) -> "Callable[..., AdbcConnection]":
self._set_adbc()
driver_path = cast("str", self.driver_name)
try:
Expand All @@ -166,7 +160,7 @@ def _get_connect_func(self) -> "Callable[..., Connection]":
raise ImproperConfigurationError(msg)
return connect_func # type: ignore[no-any-return]

def create_connection(self) -> "Connection":
def create_connection(self) -> "AdbcConnection":
"""Create and return a new database connection using the specific driver.

Returns:
Expand All @@ -189,7 +183,7 @@ def create_connection(self) -> "Connection":
raise ImproperConfigurationError(msg) from e

@contextmanager
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[AdbcConnection, None, None]":
"""Create and provide a database connection using the specific driver.

Yields:
Expand Down
62 changes: 32 additions & 30 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T

__all__ = ("AdbcDriver",)
__all__ = ("AdbcConnection", "AdbcDriver")

logger = logging.getLogger("sqlspec")

Expand All @@ -33,24 +33,26 @@
re.VERBOSE | re.DOTALL,
)

AdbcConnection = Connection


class AdbcDriver(
SyncArrowBulkOperationsMixin["Connection"],
SQLTranslatorMixin["Connection"],
SyncDriverAdapterProtocol["Connection"],
SyncArrowBulkOperationsMixin["AdbcConnection"],
SQLTranslatorMixin["AdbcConnection"],
SyncDriverAdapterProtocol["AdbcConnection"],
):
"""ADBC Sync Driver Adapter."""

connection: Connection
connection: AdbcConnection
__supports_arrow__: ClassVar[bool] = True

def __init__(self, connection: "Connection") -> None:
def __init__(self, connection: "AdbcConnection") -> None:
"""Initialize the ADBC driver adapter."""
self.connection = connection
self.dialect = self._get_dialect(connection)

@staticmethod
def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911
def _get_dialect(connection: "AdbcConnection") -> str: # noqa: PLR0911
"""Get the database dialect based on the driver name.

Args:
Expand All @@ -75,11 +77,11 @@ def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911
return "postgres" # default to postgresql dialect

@staticmethod
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
def _cursor(connection: "AdbcConnection", *args: Any, **kwargs: Any) -> "Cursor":
return connection.cursor(*args, **kwargs)

@contextmanager
def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]:
def _with_cursor(self, connection: "AdbcConnection") -> Generator["Cursor", None, None]:
cursor = self._cursor(connection)
try:
yield cursor
Expand Down Expand Up @@ -172,7 +174,7 @@ def select(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "Sequence[dict[str, Any]]": ...
Expand All @@ -183,7 +185,7 @@ def select(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
**kwargs: Any,
) -> "Sequence[ModelDTOT]": ...
Expand All @@ -193,7 +195,7 @@ def select(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
**kwargs: Any,
) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
Expand Down Expand Up @@ -223,7 +225,7 @@ def select_one(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "dict[str, Any]": ...
Expand All @@ -234,7 +236,7 @@ def select_one(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
**kwargs: Any,
) -> "ModelDTOT": ...
Expand All @@ -244,7 +246,7 @@ def select_one(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
**kwargs: Any,
) -> "Union[ModelDTOT, dict[str, Any]]":
Expand All @@ -271,7 +273,7 @@ def select_one_or_none(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "Optional[dict[str, Any]]": ...
Expand All @@ -282,7 +284,7 @@ def select_one_or_none(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
**kwargs: Any,
) -> "Optional[ModelDTOT]": ...
Expand All @@ -292,7 +294,7 @@ def select_one_or_none(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
**kwargs: Any,
) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
Expand Down Expand Up @@ -320,7 +322,7 @@ def select_value(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "Any": ...
Expand All @@ -331,7 +333,7 @@ def select_value(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[T]",
**kwargs: Any,
) -> "T": ...
Expand All @@ -341,7 +343,7 @@ def select_value(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[T]]" = None,
**kwargs: Any,
) -> "Union[T, Any]":
Expand All @@ -367,7 +369,7 @@ def select_value_or_none(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "Optional[Any]": ...
Expand All @@ -378,7 +380,7 @@ def select_value_or_none(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[T]",
**kwargs: Any,
) -> "Optional[T]": ...
Expand All @@ -388,7 +390,7 @@ def select_value_or_none(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[T]]" = None,
**kwargs: Any,
) -> "Optional[Union[T, Any]]":
Expand All @@ -414,7 +416,7 @@ def insert_update_delete(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
**kwargs: Any,
) -> int:
"""Insert, update, or delete data from the database.
Expand All @@ -436,7 +438,7 @@ def insert_update_delete_returning(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
) -> "dict[str, Any]": ...
Expand All @@ -447,7 +449,7 @@ def insert_update_delete_returning(
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
**kwargs: Any,
) -> "ModelDTOT": ...
Expand All @@ -457,7 +459,7 @@ def insert_update_delete_returning(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
**kwargs: Any,
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
Expand Down Expand Up @@ -490,7 +492,7 @@ def execute_script(
parameters: Optional["StatementParameterType"] = None,
/,
*,
connection: Optional["Connection"] = None,
connection: Optional["AdbcConnection"] = None,
**kwargs: Any,
) -> str:
"""Execute a script.
Expand All @@ -513,7 +515,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType]
parameters: "Optional[StatementParameterType]" = None,
/,
*,
connection: "Optional[Connection]" = None,
connection: "Optional[AdbcConnection]" = None,
**kwargs: Any,
) -> "ArrowTable":
"""Execute a SQL query and return results as an Apache Arrow Table.
Expand Down
3 changes: 2 additions & 1 deletion sqlspec/adapters/aiosqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver
from sqlspec.adapters.aiosqlite.driver import AiosqliteConnection, AiosqliteDriver

__all__ = (
"AiosqliteConfig",
"AiosqliteConnection",
"AiosqliteDriver",
)
Loading