diff --git a/README.md b/README.md index c2e794c43..431857ac5 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,11 @@ etl_config = sql.add_config( ) ) with sql.provide_session(etl_config) as session: - result = session.select_one("SELECT open_prompt(?)", "Can you write a haiku about DuckDB?", schema_type=ChatMessage) + result = session.select_one( + "SELECT open_prompt(?)", + "Can you write a haiku about DuckDB?", + schema_type=ChatMessage + ) print(result) # result is a ChatMessage pydantic model ``` diff --git a/docs/PYPI_README.md b/docs/PYPI_README.md index c2e794c43..431857ac5 100644 --- a/docs/PYPI_README.md +++ b/docs/PYPI_README.md @@ -70,7 +70,11 @@ etl_config = sql.add_config( ) ) with sql.provide_session(etl_config) as session: - result = session.select_one("SELECT open_prompt(?)", "Can you write a haiku about DuckDB?", schema_type=ChatMessage) + result = session.select_one( + "SELECT open_prompt(?)", + "Can you write a haiku about DuckDB?", + schema_type=ChatMessage + ) print(result) # result is a ChatMessage pydantic model ``` diff --git a/sqlspec/__init__.py b/sqlspec/__init__.py index 911d2d843..35e87c2e9 100644 --- a/sqlspec/__init__.py +++ b/sqlspec/__init__.py @@ -1,4 +1,4 @@ -from sqlspec import adapters, base, exceptions, extensions, filters, typing, utils +from sqlspec import adapters, base, exceptions, extensions, filters, mixins, typing, utils from sqlspec.__metadata__ import __version__ from sqlspec.base import SQLSpec @@ -10,6 +10,7 @@ "exceptions", "extensions", "filters", + "mixins", "typing", "utils", ) diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 6b3f7ec2b..1562bea8c 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -7,13 +7,14 @@ from adbc_driver_manager.dbapi import Connection, Cursor -from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T +from sqlspec.base import SyncDriverAdapterProtocol from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.mixins import SQLTranslatorMixin, SyncArrowBulkOperationsMixin from sqlspec.statement import SQLStatement from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: - from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType + from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T __all__ = ("AdbcDriver",) @@ -33,7 +34,11 @@ ) -class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]): +class AdbcDriver( + SyncArrowBulkOperationsMixin["Connection"], + SQLTranslatorMixin["Connection"], + SyncDriverAdapterProtocol["Connection"], +): """ADBC Sync Driver Adapter.""" connection: Connection diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index fc96a99fc..860187505 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -1,19 +1,23 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload -from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol +from sqlspec.mixins import SQLTranslatorMixin if TYPE_CHECKING: from collections.abc import AsyncGenerator, Sequence from aiosqlite import Connection, Cursor - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AiosqliteDriver",) -class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]): +class AiosqliteDriver( + SQLTranslatorMixin["Connection"], + AsyncDriverAdapterProtocol["Connection"], +): """SQLite Async Driver Adapter.""" connection: "Connection" diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 19552c6e4..c52bdd390 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -3,18 +3,22 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload -from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol +from sqlspec.mixins import SQLTranslatorMixin if TYPE_CHECKING: from asyncmy import Connection from asyncmy.cursors import Cursor - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AsyncmyDriver",) -class AsyncmyDriver(AsyncDriverAdapterProtocol["Connection"]): +class AsyncmyDriver( + SQLTranslatorMixin["Connection"], + AsyncDriverAdapterProtocol["Connection"], +): """Asyncmy MySQL/MariaDB Driver Adapter.""" connection: "Connection" diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index e5b19564d..6d23505e3 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -5,8 +5,9 @@ from asyncpg import Connection from typing_extensions import TypeAlias -from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol from sqlspec.exceptions import SQLParsingError +from sqlspec.mixins import SQLTranslatorMixin from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: @@ -15,7 +16,7 @@ from asyncpg.connection import Connection from asyncpg.pool import PoolConnectionProxy - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AsyncpgConnection", "AsyncpgDriver") @@ -35,7 +36,10 @@ AsyncpgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument] -class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]): +class AsyncpgDriver( + SQLTranslatorMixin["AsyncpgConnection"], + AsyncDriverAdapterProtocol["AsyncpgConnection"], +): """AsyncPG Postgres Driver Adapter.""" connection: "AsyncpgConnection" diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index dd9c6eb1e..6f0238103 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,19 +1,25 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload -from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T +from sqlspec.base import SyncDriverAdapterProtocol +from sqlspec.mixins import SQLTranslatorMixin, SyncArrowBulkOperationsMixin +from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: from collections.abc import Generator, Sequence from duckdb import DuckDBPyConnection - from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType + from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T __all__ = ("DuckDBDriver",) -class DuckDBDriver(SyncArrowBulkOperationsMixin["DuckDBPyConnection"], SyncDriverAdapterProtocol["DuckDBPyConnection"]): +class DuckDBDriver( + SyncArrowBulkOperationsMixin["DuckDBPyConnection"], + SQLTranslatorMixin["DuckDBPyConnection"], + SyncDriverAdapterProtocol["DuckDBPyConnection"], +): """DuckDB Sync Driver Adapter.""" connection: "DuckDBPyConnection" diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 32cdb2e9e..5d43d8757 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1,14 +1,9 @@ from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload -from sqlspec.base import ( - AsyncArrowBulkOperationsMixin, - AsyncDriverAdapterProtocol, - SyncArrowBulkOperationsMixin, - SyncDriverAdapterProtocol, - T, -) -from sqlspec.typing import ArrowTable, StatementParameterType +from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol +from sqlspec.mixins import AsyncArrowBulkOperationsMixin, SQLTranslatorMixin, SyncArrowBulkOperationsMixin +from sqlspec.typing import ArrowTable, StatementParameterType, T if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Sequence @@ -21,7 +16,11 @@ __all__ = ("OracleAsyncDriver", "OracleSyncDriver") -class OracleSyncDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]): +class OracleSyncDriver( + SyncArrowBulkOperationsMixin["Connection"], + SQLTranslatorMixin["Connection"], + SyncDriverAdapterProtocol["Connection"], +): """Oracle Sync Driver Adapter.""" connection: "Connection" @@ -433,7 +432,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] class OracleAsyncDriver( - AsyncArrowBulkOperationsMixin["AsyncConnection"], AsyncDriverAdapterProtocol["AsyncConnection"] + AsyncArrowBulkOperationsMixin["AsyncConnection"], + SQLTranslatorMixin["AsyncConnection"], + AsyncDriverAdapterProtocol["AsyncConnection"], ): """Oracle Async Driver Adapter.""" diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 0b3f82a2c..9fef18717 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -7,8 +7,9 @@ from psqlpy.exceptions import RustPSQLDriverPyBaseError -from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol from sqlspec.exceptions import SQLParsingError +from sqlspec.mixins import SQLTranslatorMixin from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: @@ -16,7 +17,7 @@ from psqlpy import Connection, QueryResult - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("PsqlpyDriver",) @@ -33,7 +34,10 @@ logger = logging.getLogger("sqlspec") -class PsqlpyDriver(AsyncDriverAdapterProtocol["Connection"]): +class PsqlpyDriver( + SQLTranslatorMixin["Connection"], + AsyncDriverAdapterProtocol["Connection"], +): """Psqlpy Postgres Driver Adapter.""" connection: "Connection" diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 07808cbc3..4d94c3bbe 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -4,8 +4,9 @@ from psycopg.rows import dict_row -from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol from sqlspec.exceptions import SQLParsingError +from sqlspec.mixins import SQLTranslatorMixin from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: @@ -13,14 +14,14 @@ from psycopg import AsyncConnection, Connection - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T logger = logging.getLogger("sqlspec") __all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver") -class PsycopgParameterParser: +class PsycopgDriverBase: dialect: str def _process_sql_params( @@ -76,7 +77,11 @@ def _process_sql_params( return processed_sql, processed_params -class PsycopgSyncDriver(PsycopgParameterParser, SyncDriverAdapterProtocol["Connection"]): +class PsycopgSyncDriver( + PsycopgDriverBase, + SQLTranslatorMixin["Connection"], + SyncDriverAdapterProtocol["Connection"], +): """Psycopg Sync Driver Adapter.""" connection: "Connection" @@ -482,7 +487,11 @@ def execute_script( return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" -class PsycopgAsyncDriver(PsycopgParameterParser, AsyncDriverAdapterProtocol["AsyncConnection"]): +class PsycopgAsyncDriver( + PsycopgDriverBase, + SQLTranslatorMixin["AsyncConnection"], + AsyncDriverAdapterProtocol["AsyncConnection"], +): """Psycopg Async Driver Adapter.""" connection: "AsyncConnection" diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index a9342ae00..f78f00239 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -2,17 +2,21 @@ from sqlite3 import Connection, Cursor from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload -from sqlspec.base import SyncDriverAdapterProtocol, T +from sqlspec.base import SyncDriverAdapterProtocol +from sqlspec.mixins import SQLTranslatorMixin if TYPE_CHECKING: from collections.abc import Generator, Sequence - from sqlspec.typing import ModelDTOT, StatementParameterType + from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("SqliteDriver",) -class SqliteDriver(SyncDriverAdapterProtocol["Connection"]): +class SqliteDriver( + SQLTranslatorMixin["Connection"], + SyncDriverAdapterProtocol["Connection"], +): """SQLite Sync Driver Adapter.""" connection: "Connection" diff --git a/sqlspec/base.py b/sqlspec/base.py index 5e4c09a96..ba33389eb 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -20,16 +20,14 @@ from sqlspec.exceptions import NotFoundError from sqlspec.statement import SQLStatement -from sqlspec.typing import ModelDTOT, StatementParameterType +from sqlspec.typing import ConnectionT, ModelDTOT, PoolT, StatementParameterType, T from sqlspec.utils.sync_tools import maybe_async_ if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager, AbstractContextManager - from pyarrow import Table as ArrowTable __all__ = ( - "AsyncArrowBulkOperationsMixin", "AsyncDatabaseConfig", "AsyncDriverAdapterProtocol", "CommonDriverAttributes", @@ -39,15 +37,10 @@ "NoPoolSyncConfig", "SQLSpec", "SQLStatement", - "SyncArrowBulkOperationsMixin", "SyncDatabaseConfig", "SyncDriverAdapterProtocol", ) -T = TypeVar("T") -ConnectionT = TypeVar("ConnectionT") -PoolT = TypeVar("PoolT") -PoolT_co = TypeVar("PoolT_co", covariant=True) AsyncConfigT = TypeVar("AsyncConfigT", bound="Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]]") SyncConfigT = TypeVar("SyncConfigT", bound="Union[SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]") ConfigT = TypeVar( @@ -558,35 +551,6 @@ def _process_sql_params( return stmt.process() -class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): - """Mixin for sync drivers supporting bulk Apache Arrow operations.""" - - __supports_arrow__: "ClassVar[bool]" = True - - @abstractmethod - def select_arrow( # pyright: ignore[reportUnknownParameterType] - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - *, - connection: "Optional[ConnectionT]" = None, - **kwargs: Any, - ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] - """Execute a SQL query and return results as an Apache Arrow Table. - - Args: - sql: The SQL query string. - parameters: Parameters for the query. - connection: Optional connection override. - **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. - - Returns: - An Apache Arrow Table containing the query results. - """ - raise NotImplementedError - - class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): connection: "ConnectionT" @@ -844,35 +808,6 @@ def execute_script( ) -> str: ... -class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]): - """Mixin for async drivers supporting bulk Apache Arrow operations.""" - - __supports_arrow__: "ClassVar[bool]" = True - - @abstractmethod - async def select_arrow( # pyright: ignore[reportUnknownParameterType] - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - *, - connection: "Optional[ConnectionT]" = None, - **kwargs: Any, - ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] - """Execute a SQL query and return results as an Apache Arrow Table. - - Args: - sql: The SQL query string. - parameters: Parameters for the query. - connection: Optional connection override. - **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. - - Returns: - An Apache Arrow Table containing the query results. - """ - raise NotImplementedError - - class AsyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): connection: "ConnectionT" diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index 73eac9e32..67b45c6a3 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -78,6 +78,15 @@ def __init__(self, message: Optional[str] = None) -> None: super().__init__(message) +class SQLConversionError(SQLSpecError): + """Issues converting SQL statements.""" + + def __init__(self, message: Optional[str] = None) -> None: + if message is None: + message = "Issues converting SQL statement." + super().__init__(message) + + class ParameterStyleMismatchError(SQLSpecError): """Error when parameter style doesn't match SQL placeholder style. diff --git a/sqlspec/extensions/litestar/config.py b/sqlspec/extensions/litestar/config.py index 59e0ed167..a39834ca4 100644 --- a/sqlspec/extensions/litestar/config.py +++ b/sqlspec/extensions/litestar/config.py @@ -1,10 +1,6 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Callable, Literal, Optional, Union -from sqlspec.base import ( - ConnectionT, - PoolT, -) from sqlspec.exceptions import ImproperConfigurationError from sqlspec.extensions.litestar.handlers import ( autocommit_handler_maker, @@ -23,13 +19,9 @@ from litestar.datastructures.state import State from litestar.types import BeforeMessageSendHookHandler, Scope - from sqlspec.base import ( - AsyncConfigT, - ConnectionT, - DriverT, - PoolT, - SyncConfigT, - ) + from sqlspec.base import AsyncConfigT, DriverT, SyncConfigT + from sqlspec.typing import ConnectionT, PoolT + CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"] DEFAULT_COMMIT_MODE: CommitMode = "manual" diff --git a/sqlspec/extensions/litestar/handlers.py b/sqlspec/extensions/litestar/handlers.py index 1d5e336f9..0d7cab1ea 100644 --- a/sqlspec/extensions/litestar/handlers.py +++ b/sqlspec/extensions/litestar/handlers.py @@ -19,7 +19,8 @@ from litestar.datastructures.state import State from litestar.types import Message, Scope - from sqlspec.base import ConnectionT, DatabaseConfigProtocol, DriverT, PoolT + from sqlspec.base import DatabaseConfigProtocol, DriverT + from sqlspec.typing import ConnectionT, PoolT SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE} diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index fd2b5564d..11e6820f2 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -5,15 +5,14 @@ from sqlspec.base import ( AsyncConfigT, - ConnectionT, DatabaseConfigProtocol, DriverT, - PoolT, SyncConfigT, ) from sqlspec.base import SQLSpec as SQLSpecBase from sqlspec.exceptions import ImproperConfigurationError from sqlspec.extensions.litestar.config import DatabaseConfig +from sqlspec.typing import ConnectionT, PoolT if TYPE_CHECKING: from click import Group diff --git a/sqlspec/mixins.py b/sqlspec/mixins.py new file mode 100644 index 000000000..067bb5a0f --- /dev/null +++ b/sqlspec/mixins.py @@ -0,0 +1,156 @@ +from abc import abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Optional, +) + +from sqlglot import parse_one +from sqlglot.dialects.dialect import DialectType + +from sqlspec.exceptions import SQLConversionError, SQLParsingError +from sqlspec.typing import ConnectionT, StatementParameterType + +if TYPE_CHECKING: + from sqlspec.typing import ArrowTable + +__all__ = ( + "AsyncArrowBulkOperationsMixin", + "AsyncParquetExportMixin", + "SQLTranslatorMixin", + "SyncArrowBulkOperationsMixin", + "SyncParquetExportMixin", +) + + +class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): + """Mixin for sync drivers supporting bulk Apache Arrow operations.""" + + __supports_arrow__: "ClassVar[bool]" = True + + @abstractmethod + def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + **kwargs: Any, + ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] + """Execute a SQL query and return results as an Apache Arrow Table. + + Args: + sql: The SQL query string. + parameters: Parameters for the query. + connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + + Returns: + An Apache Arrow Table containing the query results. + """ + raise NotImplementedError + + +class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]): + """Mixin for async drivers supporting bulk Apache Arrow operations.""" + + __supports_arrow__: "ClassVar[bool]" = True + + @abstractmethod + async def select_arrow( # pyright: ignore[reportUnknownParameterType] + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + **kwargs: Any, + ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] + """Execute a SQL query and return results as an Apache Arrow Table. + + Args: + sql: The SQL query string. + parameters: Parameters for the query. + connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + + Returns: + An Apache Arrow Table containing the query results. + """ + raise NotImplementedError + + +class SyncParquetExportMixin(Generic[ConnectionT]): + """Mixin for sync drivers supporting Parquet export.""" + + @abstractmethod + def select_to_parquet( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + **kwargs: Any, + ) -> None: + """Export a SQL query to a Parquet file.""" + raise NotImplementedError + + +class AsyncParquetExportMixin(Generic[ConnectionT]): + """Mixin for async drivers supporting Parquet export.""" + + @abstractmethod + async def select_to_parquet( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + **kwargs: Any, + ) -> None: + """Export a SQL query to a Parquet file.""" + raise NotImplementedError + + +class SQLTranslatorMixin(Generic[ConnectionT]): + """Mixin for drivers supporting SQL translation.""" + + dialect: str + + def convert_to_dialect( + self, + sql: str, + to_dialect: DialectType = None, + pretty: bool = True, + ) -> str: + """Convert a SQL query to a different dialect. + + Args: + sql: The SQL query string to convert. + to_dialect: The target dialect to convert to. + pretty: Whether to pretty-print the SQL query. + + Returns: + The converted SQL query string. + + Raises: + SQLParsingError: If the SQL query cannot be parsed. + SQLConversionError: If the SQL query cannot be converted to the target dialect. + """ + try: + parsed = parse_one(sql, dialect=self.dialect) + except Exception as e: + error_msg = f"Failed to parse SQL: {e!s}" + raise SQLParsingError(error_msg) from e + if to_dialect is None: + to_dialect = self.dialect + try: + return parsed.sql(dialect=to_dialect, pretty=pretty) + except Exception as e: + error_msg = f"Failed to convert SQL to {to_dialect}: {e!s}" + raise SQLConversionError(error_msg) from e diff --git a/sqlspec/typing.py b/sqlspec/typing.py index ee0aae7a1..92de8f9e3 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -33,8 +33,26 @@ T = TypeVar("T") +ConnectionT = TypeVar("ConnectionT") +"""Type variable for connection types. -ModelT = TypeVar("ModelT", bound="Union[Struct, BaseModel, DataclassProtocol]") +:class:`~sqlspec.typing.ConnectionT` +""" +PoolT = TypeVar("PoolT") +"""Type variable for pool types. + +:class:`~sqlspec.typing.PoolT` +""" +PoolT_co = TypeVar("PoolT_co", covariant=True) +"""Type variable for covariant pool types. + +:class:`~sqlspec.typing.PoolT_co` +""" +ModelT = TypeVar("ModelT", bound="Union[dict[str, Any], Struct, BaseModel, DataclassProtocol]") +"""Type variable for model types. + +:class:`dict[str, Any]` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`DataclassProtocol` +""" FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter") """Type variable for filter types. diff --git a/uv.lock b/uv.lock index e819b6add..eceb80cbc 100644 --- a/uv.lock +++ b/uv.lock @@ -3145,11 +3145,11 @@ wheels = [ [[package]] name = "soupsieve" -version = "2.6" +version = "2.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d7/ce/fbaeed4f9fb8b2daa961f90591662df6a86c1abf25c548329a86920aedfb/soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb", size = 101569 } +sdist = { url = "https://files.pythonhosted.org/packages/3f/f4/4a80cd6ef364b2e8b65b15816a843c0980f7a5a2b4dc701fc574952aa19f/soupsieve-2.7.tar.gz", hash = "sha256:ad282f9b6926286d2ead4750552c8a6142bc4c783fd66b0293547c8fe6ae126a", size = 103418 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9", size = 36186 }, + { url = "https://files.pythonhosted.org/packages/e7/9c/0e6afc12c269578be5c0c1c9f4b49a8d32770a080260c333ac04cc1c832d/soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4", size = 36677 }, ] [[package]]