diff --git a/pyproject.toml b/pyproject.toml index 48d4d6174..b498ef77f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,8 @@ exclude_lines = [ [tool.pytest.ini_options] addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config" +asyncio_default_fixture_loop_scope = "function" +asyncio_mode = "auto" testpaths = ["tests"] xfail_strict = true diff --git a/sqlspec/_typing.py b/sqlspec/_typing.py index f56c38ecd..61ad7c9ca 100644 --- a/sqlspec/_typing.py +++ b/sqlspec/_typing.py @@ -5,15 +5,26 @@ from __future__ import annotations +from enum import Enum from typing import ( Any, ClassVar, + Final, Protocol, + Union, cast, runtime_checkable, ) -from typing_extensions import TypeVar, dataclass_transform +from typing_extensions import Literal, TypeVar, dataclass_transform + + +@runtime_checkable +class DataclassProtocol(Protocol): + """Protocol for instance checking dataclasses.""" + + __dataclass_fields__: ClassVar[dict[str, Any]] + T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) @@ -99,11 +110,26 @@ class UnsetType(enum.Enum): # type: ignore[no-redef] UNSET = UnsetType.UNSET # pyright: ignore[reportConstantRedefinition] MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition] + +class EmptyEnum(Enum): + """A sentinel enum used as placeholder.""" + + EMPTY = 0 + + +EmptyType = Union[Literal[EmptyEnum.EMPTY], UnsetType] +Empty: Final = EmptyEnum.EMPTY + + __all__ = ( "MSGSPEC_INSTALLED", "PYDANTIC_INSTALLED", "UNSET", "BaseModel", + "DataclassProtocol", + "Empty", + "EmptyEnum", + "EmptyType", "FailFast", "Struct", "TypeAdapter", diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 0c3e56153..9ed4ddc6a 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, TypeVar from sqlspec.config import GenericDatabaseConfig -from sqlspec.utils.empty import Empty +from sqlspec.typing import Empty, EmptyType if TYPE_CHECKING: from collections.abc import Generator @@ -13,8 +13,6 @@ from adbc_driver_manager.dbapi import Connection, Cursor - from sqlspec.utils.empty import EmptyType - __all__ = ("AdbcDatabaseConfig",) ConnectionT = TypeVar("ConnectionT", bound="Connection") diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 1067c6966..05a891d59 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -6,8 +6,7 @@ from sqlspec.config import GenericDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict -from sqlspec.utils.empty import Empty, EmptyType +from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -60,7 +59,7 @@ def connection_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the aiosqlite.connect() function. """ - return simple_asdict(self, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self, exclude_empty=True, convert_nested=False) async def create_connection(self) -> Connection: """Create and return a new database connection. diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index b27a2a117..e725d164e 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -5,8 +5,7 @@ from typing import TYPE_CHECKING, TypeVar from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict -from sqlspec.utils.empty import Empty, EmptyType +from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -101,7 +100,7 @@ def pool_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the Asyncmy create_pool function. """ - return simple_asdict(self, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self, exclude_empty=True, convert_nested=False) @dataclass @@ -125,7 +124,7 @@ def pool_config_dict(self) -> dict[str, Any]: A string keyed dict of config kwargs for the Asyncmy create_pool function. """ if self.pool_config: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index f2d241375..71a41b2cd 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -10,8 +10,7 @@ from sqlspec._serialization import decode_json, encode_json from sqlspec.config import GenericDatabaseConfig, GenericPoolConfig from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict -from sqlspec.utils.empty import Empty, EmptyType +from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: from asyncio import AbstractEventLoop @@ -98,7 +97,7 @@ def pool_config_dict(self) -> dict[str, Any]: function. """ if self.pool_config: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) @@ -125,7 +124,7 @@ async def create_pool(self) -> Pool: return self.pool_instance @asynccontextmanager - async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]: + async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]: db_pool = await self.create_pool() try: yield @@ -133,7 +132,7 @@ async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]: db_pool.terminate() await db_pool.close() - def provide_pool(self, *args: Any, **kwargs) -> Awaitable[Pool]: + def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[Pool]: """Create a pool instance. Returns: diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 8f171a087..ee10b8e16 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -2,19 +2,66 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from sqlspec.config import GenericDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict -from sqlspec.utils.empty import Empty, EmptyType +from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from duckdb import DuckDBPyConnection -__all__ = ("DuckDBConfig",) +__all__ = ("DuckDBConfig", "ExtensionConfig") + + +@dataclass +class ExtensionConfig: + """Configuration for a DuckDB extension. + + This class provides configuration options for DuckDB extensions, including installation + and post-install configuration settings. + + Args: + name: The name of the extension to install + config: Optional configuration settings to apply after installation + force_install: Whether to force reinstall if already present + repository: Optional repository name to install from + repository_url: Optional repository URL to install from + version: Optional version of the extension to install + """ + + name: str + config: dict[str, Any] | None = None + force_install: bool = False + repository: str | None = None + repository_url: str | None = None + version: str | None = None + + @classmethod + def from_dict(cls, name: str, config: dict[str, Any] | bool | None = None) -> ExtensionConfig: + """Create an ExtensionConfig from a configuration dictionary. + + Args: + name: The name of the extension + config: Configuration dictionary that may contain settings + + Returns: + A new ExtensionConfig instance + """ + if config is None: + return cls(name=name) + + if not isinstance(config, dict): + config = {"force_install": bool(config)} + + install_args = { + key: config.pop(key) + for key in ["force_install", "repository", "repository_url", "version", "config", "name"] + if key in config + } + return cls(name=name, **install_args) @dataclass @@ -39,6 +86,73 @@ class DuckDBConfig(GenericDatabaseConfig): For details see: https://duckdb.org/docs/api/python/overview#connection-options """ + extensions: Sequence[ExtensionConfig] | EmptyType = Empty + """A sequence of extension configurations to install and configure upon connection creation.""" + + def __post_init__(self) -> None: + """Post-initialization validation and processing. + + This method handles merging extension configurations from both the extensions field + and the config dictionary, if present. The config['extensions'] field can be either: + - A dictionary mapping extension names to their configurations + - A list of extension names (which will be installed with force_install=True) + + Raises: + ImproperConfigurationError: If there are duplicate extension configurations. + """ + if self.config is Empty: + self.config = {} + + if self.extensions is Empty: + self.extensions = [] + # this is purely for mypy + assert isinstance(self.config, dict) # noqa: S101 + assert isinstance(self.extensions, list) # noqa: S101 + + _e = self.config.pop("extensions", {}) + if not isinstance(_e, (dict, list, tuple)): + msg = "When configuring extensions in the 'config' dictionary, the value must be a dictionary or sequence of extension names" + raise ImproperConfigurationError(msg) + if not isinstance(_e, dict): + _e = {str(ext): {"force_install": False} for ext in _e} + + if len(set(_e.keys()).intersection({ext.name for ext in self.extensions})) > 0: + msg = "Configuring the same extension in both 'extensions' and as a key in 'config['extensions']' is not allowed" + raise ImproperConfigurationError(msg) + + self.extensions.extend([ExtensionConfig.from_dict(name, ext_config) for name, ext_config in _e.items()]) + + def _configure_extensions(self, connection: DuckDBPyConnection) -> None: + """Configure extensions for the connection. + + Args: + connection: The DuckDB connection to configure extensions for. + + Raises: + ImproperConfigurationError: If extension installation or configuration fails. + """ + if self.extensions is Empty: + return + + for extension in cast("list[ExtensionConfig]", self.extensions): + try: + if extension.force_install: + connection.install_extension( + extension=extension.name, + force_install=extension.force_install, + repository=extension.repository, + repository_url=extension.repository_url, + version=extension.version, + ) + connection.load_extension(extension.name) + + if extension.config: + for key, value in extension.config.items(): + connection.execute(f"SET {key}={value}") + except Exception as e: + msg = f"Failed to configure extension {extension.name}. Error: {e!s}" + raise ImproperConfigurationError(msg) from e + @property def connection_config_dict(self) -> dict[str, Any]: """Return the connection configuration as a dict. @@ -46,24 +160,26 @@ def connection_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the duckdb.connect() function. """ - config = simple_asdict(self, exclude_empty=True, convert_nested=False) + config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False) if not config.get("database"): config["database"] = ":memory:" return config def create_connection(self) -> DuckDBPyConnection: - """Create and return a new database connection. + """Create and return a new database connection with configured extensions. Returns: - A new DuckDB connection instance. + A new DuckDB connection instance with extensions installed and configured. Raises: - ImproperConfigurationError: If the connection could not be established. + ImproperConfigurationError: If the connection could not be established or extensions could not be configured. """ import duckdb try: - return duckdb.connect(**self.connection_config_dict) + connection = duckdb.connect(**self.connection_config_dict) + self._configure_extensions(connection) + return connection except Exception as e: msg = f"Could not configure the DuckDB connection. Error: {e!s}" raise ImproperConfigurationError(msg) from e diff --git a/sqlspec/adapters/oracledb/config/_asyncio.py b/sqlspec/adapters/oracledb/config/_asyncio.py index ded87d0bb..59dbabf6b 100644 --- a/sqlspec/adapters/oracledb/config/_asyncio.py +++ b/sqlspec/adapters/oracledb/config/_asyncio.py @@ -13,7 +13,7 @@ OracleGenericPoolConfig, ) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict +from sqlspec.typing import dataclass_to_dict if TYPE_CHECKING: from collections.abc import AsyncGenerator, Awaitable @@ -36,6 +36,11 @@ class OracleAsyncDatabaseConfig(OracleGenericDatabaseConfig[AsyncConnectionPool, pool_config: OracleAsyncPoolConfig | None = None """Oracle Pool configuration""" + pool_instance: AsyncConnectionPool | None = None + """Optional pool to use. + + If set, the plugin will use the provided pool rather than instantiate one. + """ @property def pool_config_dict(self) -> dict[str, Any]: @@ -46,7 +51,7 @@ def pool_config_dict(self) -> dict[str, Any]: function. """ if self.pool_config is not None: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) @@ -71,14 +76,14 @@ async def create_pool(self) -> AsyncConnectionPool: return self.pool_instance @asynccontextmanager - async def lifespan(self, *args: Any, **kwargs) -> AsyncGenerator[None, None]: + async def lifespan(self, *args: Any, **kwargs: Any) -> AsyncGenerator[None, None]: db_pool = await self.create_pool() try: yield finally: await db_pool.close(force=True) - def provide_pool(self, *args: Any, **kwargs) -> Awaitable[AsyncConnectionPool]: + def provide_pool(self, *args: Any, **kwargs: Any) -> Awaitable[AsyncConnectionPool]: """Create a pool instance. Returns: diff --git a/sqlspec/adapters/oracledb/config/_common.py b/sqlspec/adapters/oracledb/config/_common.py index 58c13c4d3..f72698207 100644 --- a/sqlspec/adapters/oracledb/config/_common.py +++ b/sqlspec/adapters/oracledb/config/_common.py @@ -6,7 +6,7 @@ from oracledb import ConnectionPool from sqlspec.config import GenericDatabaseConfig, GenericPoolConfig -from sqlspec.utils.empty import Empty +from sqlspec.typing import Empty if TYPE_CHECKING: import ssl @@ -17,7 +17,7 @@ from oracledb.connection import AsyncConnection, Connection from oracledb.pool import AsyncConnectionPool, ConnectionPool - from sqlspec.utils.empty import EmptyType + from sqlspec.typing import EmptyType __all__ = ( "OracleGenericDatabaseConfig", diff --git a/sqlspec/adapters/oracledb/config/_sync.py b/sqlspec/adapters/oracledb/config/_sync.py index 81566b563..069614642 100644 --- a/sqlspec/adapters/oracledb/config/_sync.py +++ b/sqlspec/adapters/oracledb/config/_sync.py @@ -13,7 +13,7 @@ OracleGenericPoolConfig, ) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict +from sqlspec.typing import dataclass_to_dict if TYPE_CHECKING: from collections.abc import Generator @@ -51,7 +51,7 @@ def pool_config_dict(self) -> dict[str, Any]: function. """ if self.pool_config: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) diff --git a/sqlspec/adapters/psycopg/config/_async.py b/sqlspec/adapters/psycopg/config/_async.py index 7988e3e23..76d18e174 100644 --- a/sqlspec/adapters/psycopg/config/_async.py +++ b/sqlspec/adapters/psycopg/config/_async.py @@ -12,7 +12,7 @@ PsycoPgGenericPoolConfig, ) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict +from sqlspec.typing import dataclass_to_dict if TYPE_CHECKING: from collections.abc import AsyncGenerator, Awaitable @@ -43,7 +43,7 @@ class PsycoPgAsyncDatabaseConfig(PsycoPgGenericDatabaseConfig[AsyncConnectionPoo def pool_config_dict(self) -> dict[str, Any]: """Return the pool configuration as a dict.""" if self.pool_config: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) diff --git a/sqlspec/adapters/psycopg/config/_common.py b/sqlspec/adapters/psycopg/config/_common.py index ac8f974f0..9a075891e 100644 --- a/sqlspec/adapters/psycopg/config/_common.py +++ b/sqlspec/adapters/psycopg/config/_common.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from sqlspec.config import GenericDatabaseConfig, GenericPoolConfig -from sqlspec.utils.empty import Empty +from sqlspec.typing import Empty if TYPE_CHECKING: from collections.abc import Callable @@ -13,7 +13,8 @@ from psycopg import AsyncConnection, Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool - from sqlspec.utils.empty import EmptyType + from sqlspec.typing import EmptyType + __all__ = ( "PsycoPgGenericDatabaseConfig", diff --git a/sqlspec/adapters/psycopg/config/_sync.py b/sqlspec/adapters/psycopg/config/_sync.py index 2ee16f49f..7136df4fa 100644 --- a/sqlspec/adapters/psycopg/config/_sync.py +++ b/sqlspec/adapters/psycopg/config/_sync.py @@ -12,7 +12,7 @@ PsycoPgGenericPoolConfig, ) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict +from sqlspec.typing import dataclass_to_dict if TYPE_CHECKING: from collections.abc import Generator @@ -43,7 +43,7 @@ class PsycoPgSyncDatabaseConfig(PsycoPgGenericDatabaseConfig[ConnectionPool, Con def pool_config_dict(self) -> dict[str, Any]: """Return the pool configuration as a dict.""" if self.pool_config: - return simple_asdict(self.pool_config, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False) msg = "'pool_config' methods can not be used when a 'pool_instance' is provided." raise ImproperConfigurationError(msg) diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 90c39fc83..4b2baea46 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -6,8 +6,7 @@ from sqlspec.config import GenericDatabaseConfig from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.utils.dataclass import simple_asdict -from sqlspec.utils.empty import Empty, EmptyType +from sqlspec.typing import Empty, EmptyType, dataclass_to_dict if TYPE_CHECKING: from collections.abc import Generator @@ -57,7 +56,7 @@ def connection_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the sqlite3.connect() function. """ - return simple_asdict(self, exclude_empty=True, convert_nested=False) + return dataclass_to_dict(self, exclude_empty=True, convert_nested=False) def create_connection(self) -> Connection: """Create and return a new database connection. @@ -71,7 +70,7 @@ def create_connection(self) -> Connection: import sqlite3 try: - return sqlite3.connect(**self.connection_config_dict) + return sqlite3.connect(**self.connection_config_dict) # type: ignore[no-any-return,unused-ignore] except Exception as e: msg = f"Could not configure the SQLite connection. Error: {e!s}" raise ImproperConfigurationError(msg) from e diff --git a/sqlspec/filters.py b/sqlspec/filters.py index 5e5366024..a5b9416b1 100644 --- a/sqlspec/filters.py +++ b/sqlspec/filters.py @@ -2,11 +2,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import ABC from collections import abc # noqa: TC003 from dataclasses import dataclass from datetime import datetime # noqa: TC003 -from typing import Generic, Literal +from typing import Generic, Literal, Protocol from typing_extensions import TypeVar @@ -27,9 +27,11 @@ T = TypeVar("T") -class StatementFilter(ABC): - @abstractmethod +class StatementFilter(Protocol): + """Protocol for filters that can be appended to a statement.""" + def append_to_statement(self, statement: str) -> str: + """Append the filter to the statement.""" return statement diff --git a/sqlspec/typing.py b/sqlspec/typing.py index c9f1d43d4..286686010 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import Field, fields from functools import lru_cache from typing import ( TYPE_CHECKING, @@ -18,15 +19,22 @@ PYDANTIC_INSTALLED, UNSET, BaseModel, + DataclassProtocol, + Empty, + EmptyType, FailFast, Struct, TypeAdapter, + UnsetType, convert, ) -from sqlspec.utils.dataclass import DataclassProtocol, is_dataclass_instance, simple_asdict if TYPE_CHECKING: - from .filters import StatementFilter + from collections.abc import Iterable + from collections.abc import Set as AbstractSet + + from sqlspec.filters import StatementFilter + PYDANTIC_USE_FAILFAST = False # leave permanently disabled for now @@ -55,6 +63,18 @@ """ +def is_dataclass_instance(obj: Any) -> TypeGuard[DataclassProtocol]: + """Check if an object is a dataclass instance. + + Args: + obj: An object to check. + + Returns: + True if the object is a dataclass instance. + """ + return hasattr(type(obj), "__dataclass_fields__") + + @lru_cache(typed=True) def get_type_adapter(f: type[T]) -> TypeAdapter[T]: """Caches and returns a pydantic type adapter. @@ -224,6 +244,104 @@ def is_msgspec_model_without_field(v: Any, field_name: str) -> TypeGuard[Struct] return not is_msgspec_model_with_field(v, field_name) +def extract_dataclass_fields( + dt: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + include: AbstractSet[str] | None = None, + exclude: AbstractSet[str] | None = None, +) -> tuple[Field[Any], ...]: + """Extract dataclass fields. + + Args: + dt: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + include: An iterable of fields to include. + exclude: An iterable of fields to exclude. + + + Returns: + A tuple of dataclass fields. + """ + include = include or set() + exclude = exclude or set() + + if common := (include & exclude): + msg = f"Fields {common} are both included and excluded." + raise ValueError(msg) + + dataclass_fields: Iterable[Field[Any]] = fields(dt) + if exclude_none: + dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not None) + if exclude_empty: + dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not Empty) + if include: + dataclass_fields = (field for field in dataclass_fields if field.name in include) + if exclude: + dataclass_fields = (field for field in dataclass_fields if field.name not in exclude) + + return tuple(dataclass_fields) + + +def extract_dataclass_items( + dt: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + include: AbstractSet[str] | None = None, + exclude: AbstractSet[str] | None = None, +) -> tuple[tuple[str, Any], ...]: + """Extract dataclass name, value pairs. + + Unlike the 'asdict' method exports by the stdlib, this function does not pickle values. + + Args: + dt: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + include: An iterable of fields to include. + exclude: An iterable of fields to exclude. + + Returns: + A tuple of key/value pairs. + """ + dataclass_fields = extract_dataclass_fields(dt, exclude_none, exclude_empty, include, exclude) + return tuple((field.name, getattr(dt, field.name)) for field in dataclass_fields) + + +def dataclass_to_dict( + obj: DataclassProtocol, + exclude_none: bool = False, + exclude_empty: bool = False, + convert_nested: bool = True, + exclude: set[str] | None = None, +) -> dict[str, Any]: + """Convert a dataclass to a dictionary. + + This method has important differences to the standard library version: + - it does not deepcopy values + - it does not recurse into collections + + Args: + obj: A dataclass instance. + exclude_none: Whether to exclude None values. + exclude_empty: Whether to exclude Empty values. + convert_nested: Whether to recursively convert nested dataclasses. + exclude: An iterable of fields to exclude. + + Returns: + A dictionary of key/value pairs. + """ + ret = {} + for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude): + value = getattr(obj, field.name) + if is_dataclass_instance(value) and convert_nested: + ret[field.name] = dataclass_to_dict(value, exclude_none, exclude_empty) + else: + ret[field.name] = getattr(obj, field.name) + return ret + + def schema_dump( data: dict[str, Any] | Struct | BaseModel | DataclassProtocol, exclude_unset: bool = True, @@ -238,7 +356,7 @@ def schema_dump( :type: dict[str, Any] """ if is_dataclass(data): - return simple_asdict(data, exclude_empty=exclude_unset) + return dataclass_to_dict(data, exclude_empty=exclude_unset) if is_pydantic_model(data): return data.model_dump(exclude_unset=exclude_unset) if is_msgspec_model(data) and exclude_unset: @@ -254,6 +372,9 @@ def schema_dump( "PYDANTIC_USE_FAILFAST", "UNSET", "BaseModel", + "DataclassProtocol", + "Empty", + "EmptyType", "FailFast", "FilterTypeT", "ModelDictListT", @@ -262,7 +383,14 @@ def schema_dump( "TypeAdapter", "UnsetType", "convert", + "dataclass_to_dict", + "extract_dataclass_fields", + "extract_dataclass_items", "get_type_adapter", + "is_dataclass", + "is_dataclass_instance", + "is_dataclass_with_field", + "is_dataclass_without_field", "is_dict", "is_dict_with_field", "is_dict_without_field", diff --git a/sqlspec/utils/dataclass.py b/sqlspec/utils/dataclass.py deleted file mode 100644 index fd0a6ea84..000000000 --- a/sqlspec/utils/dataclass.py +++ /dev/null @@ -1,138 +0,0 @@ -from __future__ import annotations - -from dataclasses import Field, fields -from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable - -from typing_extensions import TypeGuard - -from sqlspec.utils.empty import Empty - -if TYPE_CHECKING: - from collections.abc import Iterable - from collections.abc import Set as AbstractSet - from typing import Any - - -__all__ = ( - "extract_dataclass_fields", - "extract_dataclass_items", - "is_dataclass_instance", - "simple_asdict", -) - - -@runtime_checkable -class DataclassProtocol(Protocol): - """Protocol for instance checking dataclasses""" - - __dataclass_fields__: ClassVar[dict[str, Any]] - - -def is_dataclass_instance(obj: Any) -> TypeGuard[DataclassProtocol]: - """Check if an object is a dataclass instance. - - Args: - obj: An object to check. - - Returns: - True if the object is a dataclass instance. - """ - return hasattr(type(obj), "__dataclass_fields__") - - -def extract_dataclass_fields( - dt: DataclassProtocol, - exclude_none: bool = False, - exclude_empty: bool = False, - include: AbstractSet[str] | None = None, - exclude: AbstractSet[str] | None = None, -) -> tuple[Field[Any], ...]: - """Extract dataclass fields. - - Args: - dt: A dataclass instance. - exclude_none: Whether to exclude None values. - exclude_empty: Whether to exclude Empty values. - include: An iterable of fields to include. - exclude: An iterable of fields to exclude. - - - Returns: - A tuple of dataclass fields. - """ - include = include or set() - exclude = exclude or set() - - if common := (include & exclude): - msg = f"Fields {common} are both included and excluded." - raise ValueError(msg) - - dataclass_fields: Iterable[Field[Any]] = fields(dt) - if exclude_none: - dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not None) - if exclude_empty: - dataclass_fields = (field for field in dataclass_fields if getattr(dt, field.name) is not Empty) - if include: - dataclass_fields = (field for field in dataclass_fields if field.name in include) - if exclude: - dataclass_fields = (field for field in dataclass_fields if field.name not in exclude) - - return tuple(dataclass_fields) - - -def extract_dataclass_items( - dt: DataclassProtocol, - exclude_none: bool = False, - exclude_empty: bool = False, - include: AbstractSet[str] | None = None, - exclude: AbstractSet[str] | None = None, -) -> tuple[tuple[str, Any], ...]: - """Extract dataclass name, value pairs. - - Unlike the 'asdict' method exports by the stdlib, this function does not pickle values. - - Args: - dt: A dataclass instance. - exclude_none: Whether to exclude None values. - exclude_empty: Whether to exclude Empty values. - include: An iterable of fields to include. - exclude: An iterable of fields to exclude. - - Returns: - A tuple of key/value pairs. - """ - dataclass_fields = extract_dataclass_fields(dt, exclude_none, exclude_empty, include, exclude) - return tuple((field.name, getattr(dt, field.name)) for field in dataclass_fields) - - -def simple_asdict( - obj: DataclassProtocol, - exclude_none: bool = False, - exclude_empty: bool = False, - convert_nested: bool = True, - exclude: set[str] | None = None, -) -> dict[str, Any]: - """Convert a dataclass to a dictionary. - - This method has important differences to the standard library version: - - it does not deepcopy values - - it does not recurse into collections - - Args: - obj: A dataclass instance. - exclude_none: Whether to exclude None values. - exclude_empty: Whether to exclude Empty values. - convert_nested: Whether to recursively convert nested dataclasses. - exclude: An iterable of fields to exclude. - - Returns: - A dictionary of key/value pairs. - """ - ret = {} - for field in extract_dataclass_fields(obj, exclude_none, exclude_empty, exclude=exclude): - value = getattr(obj, field.name) - if is_dataclass_instance(value) and convert_nested: - ret[field.name] = simple_asdict(value, exclude_none, exclude_empty) - else: - ret[field.name] = getattr(obj, field.name) - return ret diff --git a/sqlspec/utils/empty.py b/sqlspec/utils/empty.py deleted file mode 100644 index 91bc08cdf..000000000 --- a/sqlspec/utils/empty.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import Final, Literal, Union - -from sqlspec.typing import UnsetType - -__all__ = ("Empty", "EmptyType") - - -class _EmptyEnum(Enum): - """A sentinel enum used as placeholder.""" - - EMPTY = 0 - - -EmptyType = Union[Literal[_EmptyEnum.EMPTY], UnsetType] -Empty: Final = _EmptyEnum.EMPTY diff --git a/sqlspec/utils/__init__.py b/tests/unit/__init__.py similarity index 100% rename from sqlspec/utils/__init__.py rename to tests/unit/__init__.py diff --git a/tests/unit/test_adapters/__init__.py b/tests/unit/test_adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/test_adapters/test_duckdb/__init__.py b/tests/unit/test_adapters/test_duckdb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/test_adapters/test_duckdb/test_config.py b/tests/unit/test_adapters/test_duckdb/test_config.py new file mode 100644 index 000000000..210a2a138 --- /dev/null +++ b/tests/unit/test_adapters/test_duckdb/test_config.py @@ -0,0 +1,268 @@ +"""Tests for DuckDB configuration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch + +import pytest + +from sqlspec.adapters.duckdb.config import DuckDBConfig, ExtensionConfig +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.typing import Empty + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def mock_duckdb_connection() -> Generator[MagicMock, None, None]: + """Create a mock DuckDB connection.""" + with patch("duckdb.connect") as mock_connect: + connection = MagicMock() + mock_connect.return_value = connection + yield connection + + +class TestExtensionConfig: + """Test ExtensionConfig class.""" + + def test_default_values(self) -> None: + """Test default values for ExtensionConfig.""" + config = ExtensionConfig(name="test") + assert config.name == "test" + assert config.config is None + assert not config.force_install + assert config.repository is None + assert config.repository_url is None + assert config.version is None + + def test_from_dict_empty_config(self) -> None: + """Test from_dict with empty config.""" + config = ExtensionConfig.from_dict("test") + assert config.name == "test" + assert config.config is None + assert not config.force_install + + def test_from_dict_with_install_args(self) -> None: + """Test from_dict with installation arguments.""" + config = ExtensionConfig.from_dict( + "test", + { + "force_install": True, + "repository": "custom_repo", + "repository_url": "https://example.com", + "version": "1.0.0", + "config": {"some_setting": "value"}, + }, + ) + assert config.name == "test" + assert config.force_install + assert config.repository == "custom_repo" + assert config.repository_url == "https://example.com" + assert config.version == "1.0.0" + assert config.config == {"some_setting": "value"} + + def test_from_dict_with_only_config(self) -> None: + """Test from_dict with only config settings.""" + config = ExtensionConfig.from_dict("test", {"config": {"some_setting": "value"}}) + assert config.name == "test" + assert config.config == {"some_setting": "value"} + assert not config.force_install + + +class TestDuckDBConfig: + """Test DuckDBConfig class.""" + + def test_default_values(self) -> None: + """Test default values for DuckDBConfig.""" + config = DuckDBConfig() + assert config.database is Empty + assert config.read_only is Empty + assert config.config == {} + assert isinstance(config.extensions, list) + assert not config.extensions + + def test_connection_config_dict_defaults(self) -> None: + """Test connection_config_dict with default values.""" + config = DuckDBConfig() + assert config.connection_config_dict == {"database": ":memory:", "config": {}} + + def test_connection_config_dict_with_values(self) -> None: + """Test connection_config_dict with custom values.""" + config = DuckDBConfig(database="test.db", read_only=True) + assert config.connection_config_dict == {"database": "test.db", "read_only": True, "config": {}} + + def test_extensions_from_config_dict(self) -> None: + """Test extension configuration from config dictionary.""" + config = DuckDBConfig( + config={ + "extensions": { + "ext1": True, + "ext2": { + "force_install": True, + "repository": "repo", + "config": {"setting": "value"}, + }, + } + } + ) + assert isinstance(config.extensions, list) + assert len(config.extensions) == 2 + ext1 = next(ext for ext in config.extensions if ext.name == "ext1") + ext2 = next(ext for ext in config.extensions if ext.name == "ext2") + assert ext1.force_install + assert ext2.force_install + assert ext2.repository == "repo" + assert ext2.config == {"setting": "value"} + + def test_extensions_from_list(self) -> None: + """Test extension configuration from list.""" + config = DuckDBConfig(config={"extensions": ["ext1", "ext2"]}) + assert isinstance(config.extensions, list) + assert len(config.extensions) == 2 + assert all(isinstance(ext, ExtensionConfig) for ext in config.extensions) + assert {ext.name for ext in config.extensions} == {"ext1", "ext2"} + assert all(not ext.force_install for ext in config.extensions) + + def test_extensions_from_both_sources(self) -> None: + """Test extension configuration from both extensions and config.""" + config = DuckDBConfig( + extensions=[ExtensionConfig("ext1")], + config={"extensions": {"ext2": {"force_install": True}}}, + ) + assert isinstance(config.extensions, list) + assert len(config.extensions) == 2 + assert {ext.name for ext in config.extensions} == {"ext1", "ext2"} + + def test_duplicate_extensions_error(self) -> None: + """Test error on duplicate extension configuration.""" + with pytest.raises(ImproperConfigurationError, match="Configuring the same extension"): + DuckDBConfig( + extensions=[ExtensionConfig("ext1")], + config={"extensions": {"ext1": True}}, + ) + + def test_invalid_extensions_type_error(self) -> None: + """Test error on invalid extensions type.""" + with pytest.raises( + ImproperConfigurationError, + match="When configuring extensions in the 'config' dictionary, the value must be a dictionary or sequence of extension names", + ): + DuckDBConfig(config={"extensions": 123}) + + @pytest.mark.parametrize( + ("extension_config", "expected_calls"), + [ + ( + ExtensionConfig("test", force_install=True), + [ + ( + "install_extension", + { + "extension": "test", + "force_install": True, + "repository": None, + "repository_url": None, + "version": None, + }, + ), + ("load_extension", {}), + ], + ), + ( + ExtensionConfig("test", force_install=False), + [("load_extension", {})], + ), + ( + ExtensionConfig("test", force_install=True, config={"setting": "value"}), + [ + ( + "install_extension", + { + "extension": "test", + "force_install": True, + "repository": None, + "repository_url": None, + "version": None, + }, + ), + ("load_extension", {}), + ("execute", {"query": "SET setting=value"}), + ], + ), + ( + ExtensionConfig( + "test", + force_install=True, + repository="repo", + repository_url="url", + version="1.0", + ), + [ + ( + "install_extension", + { + "extension": "test", + "force_install": True, + "repository": "repo", + "repository_url": "url", + "version": "1.0", + }, + ), + ("load_extension", {}), + ], + ), + ], + ) + def test_configure_extensions( + self, + mock_duckdb_connection: MagicMock, + extension_config: ExtensionConfig, + expected_calls: list[tuple[str, dict[str, Any]]], + ) -> None: + """Test extension configuration with various settings.""" + config = DuckDBConfig(extensions=[extension_config]) + connection = config.create_connection() + + actual_calls = [] + for method_name, kwargs in expected_calls: + method = getattr(connection, method_name) + assert method.called + if method_name == "execute": + actual_calls.append((method_name, {"query": method.call_args.args[0]})) + else: + actual_calls.append((method_name, method.call_args.kwargs)) + + assert actual_calls == expected_calls + + def test_extension_configuration_error(self, mock_duckdb_connection: MagicMock) -> None: + """Test error handling during extension configuration.""" + mock_duckdb_connection.load_extension.side_effect = Exception("Test error") + config = DuckDBConfig(extensions=[ExtensionConfig("test")]) + + with pytest.raises(ImproperConfigurationError, match="Failed to configure extension test"): + config.create_connection() + + def test_connection_creation_error(self) -> None: + """Test error handling during connection creation.""" + with patch("duckdb.connect", side_effect=Exception("Test error")): + config = DuckDBConfig() + with pytest.raises(ImproperConfigurationError, match="Could not configure"): + config.create_connection() + + def test_connection_lifecycle(self, mock_duckdb_connection: MagicMock) -> None: + """Test connection lifecycle management.""" + config = DuckDBConfig() + + # Test lifespan + with config.lifespan(): + assert mock_duckdb_connection.close.call_count == 0 + assert mock_duckdb_connection.close.call_count == 1 + + # Test provide_connection + mock_duckdb_connection.reset_mock() + with config.provide_connection() as conn: + assert conn is mock_duckdb_connection + assert mock_duckdb_connection.close.call_count == 0 + assert mock_duckdb_connection.close.call_count == 1 diff --git a/tests/unit/test_typing.py b/tests/unit/test_typing.py new file mode 100644 index 000000000..15b45268b --- /dev/null +++ b/tests/unit/test_typing.py @@ -0,0 +1,276 @@ +"""Tests for typing utilities.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, ClassVar + +import pytest +from msgspec import Struct +from pydantic import BaseModel + +from sqlspec.typing import ( + Empty, + dataclass_to_dict, + extract_dataclass_fields, + extract_dataclass_items, + is_dataclass, + is_dataclass_instance, + is_dataclass_with_field, + is_dataclass_without_field, + is_dict, + is_dict_with_field, + is_dict_without_field, + is_msgspec_model, + is_msgspec_model_with_field, + is_msgspec_model_without_field, + is_pydantic_model, + is_pydantic_model_with_field, + is_pydantic_model_without_field, + schema_dump, +) + + +@dataclass +class SampleDataclass: + """Sample dataclass for testing.""" + + name: str + value: int | None = None + empty_field: Any = Empty + meta: ClassVar[str] = "test" + + +class SamplePydanticModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str + value: int | None = None + + +class SampleMsgspecModel(Struct): + """Sample Msgspec model for testing.""" + + name: str + value: int | None = None + + +@pytest.fixture +def sample_dataclass() -> SampleDataclass: + """Create a sample dataclass instance.""" + return SampleDataclass(name="test", value=42) + + +@pytest.fixture +def sample_pydantic() -> SamplePydanticModel: + """Create a sample Pydantic model instance.""" + return SamplePydanticModel(name="test", value=42) + + +@pytest.fixture +def sample_msgspec() -> SampleMsgspecModel: + """Create a sample Msgspec model instance.""" + return SampleMsgspecModel(name="test", value=42) + + +@pytest.fixture +def sample_dict() -> dict[str, Any]: + """Create a sample dictionary.""" + return {"name": "test", "value": 42} + + +class TestTypeChecking: + """Test type checking functions.""" + + def test_is_dataclass(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass type checking.""" + assert is_dataclass(sample_dataclass) + assert not is_dataclass({"name": "test"}) + + def test_is_dataclass_instance(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass instance checking.""" + assert is_dataclass_instance(sample_dataclass) + assert not is_dataclass_instance(SampleDataclass) + assert not is_dataclass_instance({"name": "test"}) + + def test_is_dataclass_with_field(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass field checking.""" + assert is_dataclass_with_field(sample_dataclass, "name") + assert not is_dataclass_with_field(sample_dataclass, "nonexistent") + + def test_is_dataclass_without_field(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass field absence checking.""" + assert is_dataclass_without_field(sample_dataclass, "nonexistent") + assert not is_dataclass_without_field(sample_dataclass, "name") + + def test_is_pydantic_model(self, sample_pydantic: SamplePydanticModel) -> None: + """Test Pydantic model type checking.""" + assert is_pydantic_model(sample_pydantic) + assert not is_pydantic_model({"name": "test"}) + + def test_is_pydantic_model_with_field(self, sample_pydantic: SamplePydanticModel) -> None: + """Test Pydantic model field checking.""" + assert is_pydantic_model_with_field(sample_pydantic, "name") + assert not is_pydantic_model_with_field(sample_pydantic, "nonexistent") + + def test_is_pydantic_model_without_field(self, sample_pydantic: SamplePydanticModel) -> None: + """Test Pydantic model field absence checking.""" + assert is_pydantic_model_without_field(sample_pydantic, "nonexistent") + assert not is_pydantic_model_without_field(sample_pydantic, "name") + + def test_is_msgspec_model(self, sample_msgspec: SampleMsgspecModel) -> None: + """Test Msgspec model type checking.""" + assert is_msgspec_model(sample_msgspec) + assert not is_msgspec_model({"name": "test"}) + + def test_is_msgspec_model_with_field(self, sample_msgspec: SampleMsgspecModel) -> None: + """Test Msgspec model field checking.""" + assert is_msgspec_model_with_field(sample_msgspec, "name") + assert not is_msgspec_model_with_field(sample_msgspec, "nonexistent") + + def test_is_msgspec_model_without_field(self, sample_msgspec: SampleMsgspecModel) -> None: + """Test Msgspec model field absence checking.""" + assert is_msgspec_model_without_field(sample_msgspec, "nonexistent") + assert not is_msgspec_model_without_field(sample_msgspec, "name") + + def test_is_dict(self, sample_dict: dict[str, Any]) -> None: + """Test dictionary type checking.""" + assert is_dict(sample_dict) + assert not is_dict([1, 2, 3]) + + def test_is_dict_with_field(self, sample_dict: dict[str, Any]) -> None: + """Test dictionary field checking.""" + assert is_dict_with_field(sample_dict, "name") + assert not is_dict_with_field(sample_dict, "nonexistent") + + def test_is_dict_without_field(self, sample_dict: dict[str, Any]) -> None: + """Test dictionary field absence checking.""" + assert is_dict_without_field(sample_dict, "nonexistent") + assert not is_dict_without_field(sample_dict, "name") + + +class TestDataclassUtils: + """Test dataclass utility functions.""" + + def test_extract_dataclass_fields(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass field extraction.""" + fields = extract_dataclass_fields(sample_dataclass) + assert len(fields) == 3 + assert all(f.name in {"name", "value", "empty_field"} for f in fields) + + # Test exclusions + fields_no_none = extract_dataclass_fields(sample_dataclass, exclude_none=True) + assert all(getattr(sample_dataclass, f.name) is not None for f in fields_no_none) + + fields_no_empty = extract_dataclass_fields(sample_dataclass, exclude_empty=True) + assert all(getattr(sample_dataclass, f.name) is not Empty for f in fields_no_empty) + + # Test include/exclude + fields_included = extract_dataclass_fields(sample_dataclass, include={"name"}) + assert len(fields_included) == 1 + assert fields_included[0].name == "name" + + fields_excluded = extract_dataclass_fields(sample_dataclass, exclude={"name"}) + assert all(f.name != "name" for f in fields_excluded) + + # Test conflicting include/exclude + with pytest.raises(ValueError, match="both included and excluded"): + extract_dataclass_fields(sample_dataclass, include={"name"}, exclude={"name"}) + + def test_extract_dataclass_items(self, sample_dataclass: SampleDataclass) -> None: + """Test dataclass item extraction.""" + items = extract_dataclass_items(sample_dataclass) + assert len(items) == 3 + assert dict(items) == { + "name": "test", + "value": 42, + "empty_field": Empty, + } + + def test_dataclass_to_dict(self) -> None: + """Test dataclass to dictionary conversion.""" + + @dataclass + class NestedDataclass: + """Nested dataclass for testing.""" + + x: int + y: int + + @dataclass + class ComplexDataclass: + """Complex dataclass for testing.""" + + name: str + nested: NestedDataclass + value: int | None = None + empty_field: Any = Empty + items: list[str] = field(default_factory=list) + + nested = NestedDataclass(x=1, y=2) + obj = ComplexDataclass( + name="test", + nested=nested, + value=42, + items=["a", "b"], + ) + + # Test basic conversion + result = dataclass_to_dict(obj) + assert result["name"] == "test" + assert result["value"] == 42 + assert result["empty_field"] is Empty + assert result["items"] == ["a", "b"] + assert isinstance(result["nested"], dict) + assert result["nested"] == {"x": 1, "y": 2} + + # Test with exclude_empty + result = dataclass_to_dict(obj, exclude_empty=True) + assert "empty_field" not in result + + # Test with exclude_none + obj.value = None + result = dataclass_to_dict(obj, exclude_none=True) + assert "value" not in result + + # Test without nested conversion + result = dataclass_to_dict(obj, convert_nested=False) + assert isinstance(result["nested"], NestedDataclass) + + # Test with exclusions + result = dataclass_to_dict(obj, exclude={"nested", "items"}) + assert "nested" not in result + assert "items" not in result + + +class TestSchemaDump: + """Test schema dumping functionality.""" + + def test_schema_dump_dataclass(self, sample_dataclass: SampleDataclass) -> None: + """Test schema dumping for dataclasses.""" + result = schema_dump(sample_dataclass) + assert result == { + "name": "test", + "value": 42, + } + + def test_schema_dump_pydantic(self, sample_pydantic: SamplePydanticModel) -> None: + """Test schema dumping for Pydantic models.""" + result = schema_dump(sample_pydantic) + assert result == { + "name": "test", + "value": 42, + } + + def test_schema_dump_msgspec(self, sample_msgspec: SampleMsgspecModel) -> None: + """Test schema dumping for Msgspec models.""" + result = schema_dump(sample_msgspec) + assert result == { + "name": "test", + "value": 42, + } + + def test_schema_dump_dict(self, sample_dict: dict[str, Any]) -> None: + """Test schema dumping for dictionaries.""" + result = schema_dump(sample_dict) + assert result == sample_dict