Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ exclude_lines = [

[tool.pytest.ini_options]
addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config"
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
testpaths = ["tests"]
xfail_strict = true

Expand Down
28 changes: 27 additions & 1 deletion sqlspec/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@

from __future__ import annotations

from enum import Enum
from typing import (
Any,
ClassVar,
Final,
Protocol,
Union,
cast,
runtime_checkable,
)

from typing_extensions import TypeVar, dataclass_transform
from typing_extensions import Literal, TypeVar, dataclass_transform


@runtime_checkable
class DataclassProtocol(Protocol):
"""Protocol for instance checking dataclasses."""

__dataclass_fields__: ClassVar[dict[str, Any]]


T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
Expand Down Expand Up @@ -99,11 +110,26 @@ class UnsetType(enum.Enum): # type: ignore[no-redef]
UNSET = UnsetType.UNSET # pyright: ignore[reportConstantRedefinition]
MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]


class EmptyEnum(Enum):
"""A sentinel enum used as placeholder."""

EMPTY = 0


EmptyType = Union[Literal[EmptyEnum.EMPTY], UnsetType]
Empty: Final = EmptyEnum.EMPTY


__all__ = (
"MSGSPEC_INSTALLED",
"PYDANTIC_INSTALLED",
"UNSET",
"BaseModel",
"DataclassProtocol",
"Empty",
"EmptyEnum",
"EmptyType",
"FailFast",
"Struct",
"TypeAdapter",
Expand Down
4 changes: 1 addition & 3 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
from typing import TYPE_CHECKING, TypeVar

from sqlspec.config import GenericDatabaseConfig
from sqlspec.utils.empty import Empty
from sqlspec.typing import Empty, EmptyType

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any

from adbc_driver_manager.dbapi import Connection, Cursor

from sqlspec.utils.empty import EmptyType

__all__ = ("AdbcDatabaseConfig",)

ConnectionT = TypeVar("ConnectionT", bound="Connection")
Expand Down
5 changes: 2 additions & 3 deletions sqlspec/adapters/aiosqlite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from sqlspec.config import GenericDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.dataclass import simple_asdict
from sqlspec.utils.empty import Empty, EmptyType
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -60,7 +59,7 @@ def connection_config_dict(self) -> dict[str, Any]:
Returns:
A string keyed dict of config kwargs for the aiosqlite.connect() function.
"""
return simple_asdict(self, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)

async def create_connection(self) -> Connection:
"""Create and return a new database connection.
Expand Down
7 changes: 3 additions & 4 deletions sqlspec/adapters/asyncmy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from typing import TYPE_CHECKING, TypeVar

from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.dataclass import simple_asdict
from sqlspec.utils.empty import Empty, EmptyType
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -101,7 +100,7 @@ def pool_config_dict(self) -> dict[str, Any]:
Returns:
A string keyed dict of config kwargs for the Asyncmy create_pool function.
"""
return simple_asdict(self, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)


@dataclass
Expand All @@ -125,7 +124,7 @@ def pool_config_dict(self) -> dict[str, Any]:
A string keyed dict of config kwargs for the Asyncmy create_pool function.
"""
if self.pool_config:
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

Expand Down
9 changes: 4 additions & 5 deletions sqlspec/adapters/asyncpg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from sqlspec._serialization import decode_json, encode_json
from sqlspec.config import GenericDatabaseConfig, GenericPoolConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.dataclass import simple_asdict
from sqlspec.utils.empty import Empty, EmptyType
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict

if TYPE_CHECKING:
from asyncio import AbstractEventLoop
Expand Down Expand Up @@ -98,7 +97,7 @@ def pool_config_dict(self) -> dict[str, Any]:
function.
"""
if self.pool_config:
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

Expand All @@ -125,15 +124,15 @@ async def create_pool(self) -> Pool:
return self.pool_instance

@asynccontextmanager
async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]:
async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]:
db_pool = await self.create_pool()
try:
yield
finally:
db_pool.terminate()
await db_pool.close()

def provide_pool(self, *args: Any, **kwargs) -> Awaitable[Pool]:
def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[Pool]:
"""Create a pool instance.

Returns:
Expand Down
136 changes: 126 additions & 10 deletions sqlspec/adapters/duckdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,66 @@

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from sqlspec.config import GenericDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.dataclass import simple_asdict
from sqlspec.utils.empty import Empty, EmptyType
from sqlspec.typing import Empty, EmptyType, dataclass_to_dict

if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Generator, Sequence

from duckdb import DuckDBPyConnection

__all__ = ("DuckDBConfig",)
__all__ = ("DuckDBConfig", "ExtensionConfig")


@dataclass
class ExtensionConfig:
"""Configuration for a DuckDB extension.

This class provides configuration options for DuckDB extensions, including installation
and post-install configuration settings.

Args:
name: The name of the extension to install
config: Optional configuration settings to apply after installation
force_install: Whether to force reinstall if already present
repository: Optional repository name to install from
repository_url: Optional repository URL to install from
version: Optional version of the extension to install
"""

name: str
config: dict[str, Any] | None = None
force_install: bool = False
repository: str | None = None
repository_url: str | None = None
version: str | None = None

@classmethod
def from_dict(cls, name: str, config: dict[str, Any] | bool | None = None) -> ExtensionConfig:
"""Create an ExtensionConfig from a configuration dictionary.

Args:
name: The name of the extension
config: Configuration dictionary that may contain settings

Returns:
A new ExtensionConfig instance
"""
if config is None:
return cls(name=name)

if not isinstance(config, dict):
config = {"force_install": bool(config)}

install_args = {
key: config.pop(key)
for key in ["force_install", "repository", "repository_url", "version", "config", "name"]
if key in config
}
return cls(name=name, **install_args)


@dataclass
Expand All @@ -39,31 +86,100 @@ class DuckDBConfig(GenericDatabaseConfig):
For details see: https://duckdb.org/docs/api/python/overview#connection-options
"""

extensions: Sequence[ExtensionConfig] | EmptyType = Empty
"""A sequence of extension configurations to install and configure upon connection creation."""

def __post_init__(self) -> None:
"""Post-initialization validation and processing.

This method handles merging extension configurations from both the extensions field
and the config dictionary, if present. The config['extensions'] field can be either:
- A dictionary mapping extension names to their configurations
- A list of extension names (which will be installed with force_install=True)

Raises:
ImproperConfigurationError: If there are duplicate extension configurations.
"""
if self.config is Empty:
self.config = {}

if self.extensions is Empty:
self.extensions = []
# this is purely for mypy
assert isinstance(self.config, dict) # noqa: S101
assert isinstance(self.extensions, list) # noqa: S101

_e = self.config.pop("extensions", {})
if not isinstance(_e, (dict, list, tuple)):
msg = "When configuring extensions in the 'config' dictionary, the value must be a dictionary or sequence of extension names"
raise ImproperConfigurationError(msg)
if not isinstance(_e, dict):
_e = {str(ext): {"force_install": False} for ext in _e}

if len(set(_e.keys()).intersection({ext.name for ext in self.extensions})) > 0:
msg = "Configuring the same extension in both 'extensions' and as a key in 'config['extensions']' is not allowed"
raise ImproperConfigurationError(msg)

self.extensions.extend([ExtensionConfig.from_dict(name, ext_config) for name, ext_config in _e.items()])

def _configure_extensions(self, connection: DuckDBPyConnection) -> None:
"""Configure extensions for the connection.

Args:
connection: The DuckDB connection to configure extensions for.

Raises:
ImproperConfigurationError: If extension installation or configuration fails.
"""
if self.extensions is Empty:
return

for extension in cast("list[ExtensionConfig]", self.extensions):
try:
if extension.force_install:
connection.install_extension(
extension=extension.name,
force_install=extension.force_install,
repository=extension.repository,
repository_url=extension.repository_url,
version=extension.version,
)
connection.load_extension(extension.name)

if extension.config:
for key, value in extension.config.items():
connection.execute(f"SET {key}={value}")
except Exception as e:
msg = f"Failed to configure extension {extension.name}. Error: {e!s}"
raise ImproperConfigurationError(msg) from e

@property
def connection_config_dict(self) -> dict[str, Any]:
"""Return the connection configuration as a dict.

Returns:
A string keyed dict of config kwargs for the duckdb.connect() function.
"""
config = simple_asdict(self, exclude_empty=True, convert_nested=False)
config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
if not config.get("database"):
config["database"] = ":memory:"
return config

def create_connection(self) -> DuckDBPyConnection:
"""Create and return a new database connection.
"""Create and return a new database connection with configured extensions.

Returns:
A new DuckDB connection instance.
A new DuckDB connection instance with extensions installed and configured.

Raises:
ImproperConfigurationError: If the connection could not be established.
ImproperConfigurationError: If the connection could not be established or extensions could not be configured.
"""
import duckdb

try:
return duckdb.connect(**self.connection_config_dict)
connection = duckdb.connect(**self.connection_config_dict)
self._configure_extensions(connection)
return connection
except Exception as e:
msg = f"Could not configure the DuckDB connection. Error: {e!s}"
raise ImproperConfigurationError(msg) from e
Expand Down
13 changes: 9 additions & 4 deletions sqlspec/adapters/oracledb/config/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
OracleGenericPoolConfig,
)
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.dataclass import simple_asdict
from sqlspec.typing import dataclass_to_dict

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable
Expand All @@ -36,6 +36,11 @@ class OracleAsyncDatabaseConfig(OracleGenericDatabaseConfig[AsyncConnectionPool,

pool_config: OracleAsyncPoolConfig | None = None
"""Oracle Pool configuration"""
pool_instance: AsyncConnectionPool | None = None
"""Optional pool to use.

If set, the plugin will use the provided pool rather than instantiate one.
"""

@property
def pool_config_dict(self) -> dict[str, Any]:
Expand All @@ -46,7 +51,7 @@ def pool_config_dict(self) -> dict[str, Any]:
function.
"""
if self.pool_config is not None:
return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

Expand All @@ -71,14 +76,14 @@ async def create_pool(self) -> AsyncConnectionPool:
return self.pool_instance

@asynccontextmanager
async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]:
async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]:
db_pool = await self.create_pool()
try:
yield
finally:
await db_pool.close(force=True)

def provide_pool(self, *args: Any, **kwargs) -> Awaitable[AsyncConnectionPool]:
def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[AsyncConnectionPool]:
"""Create a pool instance.

Returns:
Expand Down
Loading
Loading