Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ clean: ## Cleanup temporary build a
.PHONY: test
test: ## Run the tests
@echo "${INFO} Running test cases... 🧪"
@uv run pytest tests
@uv run pytest -n 2 --dist=loadgroup tests
@echo "${OK} Tests complete ✨"

.PHONY: test-all
Expand All @@ -128,7 +128,7 @@ test-all: tests ## Run all tests
.PHONY: coverage
coverage: ## Run tests with coverage report
@echo "${INFO} Running tests with coverage... 📊"
@uv run pytest --cov -n auto --quiet
@uv run pytest --cov -n 2 --dist=loadgroup --quiet
@uv run coverage html >/dev/null 2>&1
@uv run coverage xml >/dev/null 2>&1
@echo "${OK} Coverage report generated ✨"
Expand Down
29 changes: 27 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ exclude_lines = [
]

[tool.pytest.ini_options]
addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config"
addopts = ["-q", "-ra"]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
filterwarnings = [
Expand All @@ -189,8 +189,31 @@ filterwarnings = [
"ignore::DeprecationWarning:websockets.connection",
"ignore::DeprecationWarning:websockets.legacy",
]
markers = [
"integration: marks tests that require an external database",
"postgres: marks tests specific to PostgreSQL",
"duckdb: marks tests specific to DuckDB",
"sqlite: marks tests specific to SQLite",
"bigquery: marks tests specific to Google BigQuery",
"mysql: marks tests specific to MySQL",
"oracle: marks tests specific to Oracle",
"spanner: marks tests specific to Google Cloud Spanner",
"mssql: marks tests specific to Microsoft SQL Server",
# Driver markers
"adbc: marks tests using ADBC drivers",
"aioodbc: marks tests using aioodbc",
"aiosqlite: marks tests using aiosqlite",
"asyncmy: marks tests using asyncmy",
"asyncpg: marks tests using asyncpg",
"duckdb_driver: marks tests using the duckdb driver",
"google_bigquery: marks tests using google-cloud-bigquery",
"google_spanner: marks tests using google-cloud-spanner",
"oracledb: marks tests using oracledb",
"psycopg: marks tests using psycopg",
"pymssql: marks tests using pymssql",
"pymysql: marks tests using pymysql",
]
testpaths = ["tests"]
xfail_strict = true

[tool.mypy]
packages = ["sqlspec", "tests"]
Expand Down Expand Up @@ -220,6 +243,8 @@ module = [
"uvloop.*",
"asyncmy",
"asyncmy.*",
"pyarrow",
"pyarrow.*",
]

[tool.pyright]
Expand Down
45 changes: 39 additions & 6 deletions sqlspec/_typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# ruff: noqa: RUF100, PLR0913, A002, DOC201, PLR6301
"""This is a simple wrapper around a few important classes in each library.

This is used to ensure compatibility when one or more of the libraries are installed.
"""

from collections.abc import Iterable, Mapping
from enum import Enum
from typing import (
Any,
Expand Down Expand Up @@ -96,7 +98,7 @@ def __init__(

def validate_python(
self,
object: Any, # noqa: A002
object: Any,
/,
*,
strict: "Optional[bool]" = None,
Expand Down Expand Up @@ -127,10 +129,7 @@ class FailFast: # type: ignore[no-redef]
except ImportError:
import enum
from collections.abc import Iterable
from typing import TYPE_CHECKING, Callable, Optional, Union

if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Callable, Optional, Union

@dataclass_transform()
@runtime_checkable
Expand Down Expand Up @@ -174,7 +173,6 @@ def __init__(self, backend: Any, data_as_builtins: Any) -> None:
"""Placeholder init"""

def create_instance(self, **kwargs: Any) -> "T":
"""Placeholder implementation"""
return cast("T", kwargs)

def update_instance(self, instance: "T", **kwargs: Any) -> "T":
Expand All @@ -198,11 +196,46 @@ class EmptyEnum(Enum):
Empty: Final = EmptyEnum.EMPTY


try:
from pyarrow import Table as ArrowTable

PYARROW_INSTALLED = True
except ImportError:

@runtime_checkable
class ArrowTable(Protocol): # type: ignore[no-redef]
"""Placeholder Implementation"""

def to_batches(self, batch_size: int) -> Any: ...
def num_rows(self) -> int: ...
def num_columns(self) -> int: ...
def to_pydict(self) -> dict[str, Any]: ...
def to_string(self) -> str: ...
def from_arrays(
self,
arrays: list[Any],
names: "Optional[list[str]]" = None,
schema: "Optional[Any]" = None,
metadata: "Optional[Mapping[str, Any]]" = None,
) -> Any: ...
def from_pydict(
self,
mapping: dict[str, Any],
schema: "Optional[Any]" = None,
metadata: "Optional[Mapping[str, Any]]" = None,
) -> Any: ...
def from_batches(self, batches: Iterable[Any], schema: Optional[Any] = None) -> Any: ...

PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]


__all__ = (
"LITESTAR_INSTALLED",
"MSGSPEC_INSTALLED",
"PYARROW_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"ArrowTable",
"BaseModel",
"DTOData",
"DataclassProtocol",
Expand Down
4 changes: 2 additions & 2 deletions sqlspec/adapters/adbc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlspec.adapters.adbc.config import Adbc
from sqlspec.adapters.adbc.config import AdbcConfig
from sqlspec.adapters.adbc.driver import AdbcDriver

__all__ = (
"Adbc",
"AdbcConfig",
"AdbcDriver",
)
45 changes: 34 additions & 11 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from collections.abc import Generator


__all__ = ("Adbc",)
__all__ = ("AdbcConfig",)


@dataclass
class Adbc(NoPoolSyncConfig["Connection", "AdbcDriver"]):
class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
"""Configuration for ADBC connections.

This class provides configuration options for ADBC database connections using the
Expand Down Expand Up @@ -55,17 +55,41 @@ def _set_adbc(self) -> str: # noqa: PLR0912
"""

if isinstance(self.driver_name, str):
if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and "sqlite" in self.driver_name:
if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and self.driver_name in {
"sqlite",
"sqlite3",
"adbc_driver_sqlite",
}:
self.driver_name = "adbc_driver_sqlite.dbapi.connect"
elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and "duckdb" in self.driver_name:
elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and self.driver_name in {
"duckdb",
"adbc_driver_duckdb",
}:
self.driver_name = "adbc_driver_duckdb.dbapi.connect"
elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and "postgres" in self.driver_name:
elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and self.driver_name in {
"postgres",
"adbc_driver_postgresql",
"postgresql",
"pg",
}:
self.driver_name = "adbc_driver_postgresql.dbapi.connect"
elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and "snowflake" in self.driver_name:
elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and self.driver_name in {
"snowflake",
"adbc_driver_snowflake",
"sf",
}:
self.driver_name = "adbc_driver_snowflake.dbapi.connect"
elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and "bigquery" in self.driver_name:
elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and self.driver_name in {
"bigquery",
"adbc_driver_bigquery",
"bq",
}:
self.driver_name = "adbc_driver_bigquery.dbapi.connect"
elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and "flightsql" in self.driver_name:
elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and self.driver_name in {
"flightsql",
"adbc_driver_flightsql",
"grpc",
}:
self.driver_name = "adbc_driver_flightsql.dbapi.connect"
return self.driver_name

Expand Down Expand Up @@ -153,11 +177,10 @@ def create_connection(self) -> "Connection":
"""
try:
connect_func = self._get_connect_func()
_config = self.connection_config_dict
return connect_func(**_config)
return connect_func(**self.connection_config_dict)
except Exception as e:
# Include driver name in error message for better context
driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Derived"
driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Missing"
# Use the potentially modified driver_path from _get_connect_func if available,
# otherwise fallback to self.driver_name for the error message.
# This requires _get_connect_func to potentially return the used path or store it.
Expand Down
40 changes: 32 additions & 8 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import re
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast

from adbc_driver_manager.dbapi import Connection, Cursor
from adbc_driver_manager.dbapi import Connection
from adbc_driver_manager.dbapi import Cursor as DbapiCursor

from sqlspec.base import SyncDriverAdapterProtocol, T
from sqlspec._typing import ArrowTable
from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T

if TYPE_CHECKING:
from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType

__all__ = ("AdbcDriver",)

Expand All @@ -26,10 +28,11 @@
)


class AdbcDriver(SyncDriverAdapterProtocol["Connection"]):
class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]):
"""ADBC Sync Driver Adapter."""

connection: Connection
__supports_arrow__: ClassVar[bool] = True

def __init__(self, connection: "Connection") -> None:
"""Initialize the ADBC driver adapter."""
Expand All @@ -38,12 +41,12 @@ def __init__(self, connection: "Connection") -> None:
# For now, assume 'qmark' based on typical ADBC DBAPI behavior

@staticmethod
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "DbapiCursor":
return connection.cursor(*args, **kwargs)

@contextmanager
def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]:
cursor = self._cursor(connection)
def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", None, None]:
cursor: DbapiCursor = self._cursor(connection)
try:
yield cursor
finally:
Expand Down Expand Up @@ -331,3 +334,24 @@ def execute_script_returning(
if schema_type is not None:
return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType]
return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]

# --- Arrow Bulk Operations ---

def select_arrow( # pyright: ignore[reportUnknownParameterType]
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
connection: "Optional[Connection]" = None,
) -> "ArrowTable":
"""Execute a SQL query and return results as an Apache Arrow Table.

Returns:
The results of the query as an Apache Arrow Table.
"""
conn = self._connection(connection)
sql, parameters = self._process_sql_params(sql, parameters)

with self._with_cursor(conn) as cursor:
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType]
4 changes: 2 additions & 2 deletions sqlspec/adapters/aiosqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlspec.adapters.aiosqlite.config import Aiosqlite
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver

__all__ = (
"Aiosqlite",
"AiosqliteConfig",
"AiosqliteDriver",
)
4 changes: 2 additions & 2 deletions sqlspec/adapters/aiosqlite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from typing import Literal


__all__ = ("Aiosqlite",)
__all__ = ("AiosqliteConfig",)


@dataclass
class Aiosqlite(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
"""Configuration for Aiosqlite database connections.

This class provides configuration options for Aiosqlite database connections, wrapping all parameters
Expand Down
6 changes: 3 additions & 3 deletions sqlspec/adapters/asyncmy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sqlspec.adapters.asyncmy.config import Asyncmy, AsyncmyPool
from sqlspec.adapters.asyncmy.config import AsyncmyConfig, AsyncmyPoolConfig
from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined]

__all__ = (
"Asyncmy",
"AsyncmyConfig",
"AsyncmyDriver",
"AsyncmyPool",
"AsyncmyPoolConfig",
)
16 changes: 8 additions & 8 deletions sqlspec/adapters/asyncmy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
from asyncmy.pool import Pool # pyright: ignore[reportUnknownVariableType]

__all__ = (
"Asyncmy",
"AsyncmyPool",
"AsyncmyConfig",
"AsyncmyPoolConfig",
)


T = TypeVar("T")


@dataclass
class AsyncmyPool(GenericPoolConfig):
class AsyncmyPoolConfig(GenericPoolConfig):
"""Configuration for Asyncmy's connection pool.

This class provides configuration options for Asyncmy database connection pools.
Expand Down Expand Up @@ -104,19 +104,19 @@ def pool_config_dict(self) -> "dict[str, Any]":


@dataclass
class Asyncmy(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]):
class AsyncmyConfig(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]):
"""Asyncmy Configuration."""

__is_async__ = True
__supports_connection_pooling__ = True

pool_config: "Optional[AsyncmyPool]" = None
pool_config: "Optional[AsyncmyPoolConfig]" = None
"""Asyncmy Pool configuration"""
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore
connection_type: "type[Connection]" = field(hash=False, init=False, default_factory=lambda: Connection) # pyright: ignore
"""Type of the connection object"""
driver_type: "type[AsyncmyDriver]" = field(init=False, default_factory=lambda: AsyncmyDriver)
driver_type: "type[AsyncmyDriver]" = field(hash=False, init=False, default_factory=lambda: AsyncmyDriver)
"""Type of the driver object"""
pool_instance: "Optional[Pool]" = None # pyright: ignore[reportUnknownVariableType]
pool_instance: "Optional[Pool]" = field(hash=False, default=None) # pyright: ignore[reportUnknownVariableType]
"""Instance of the pool"""

@property
Expand Down
Loading