Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ etl_config = sql.add_config(
)
)
with sql.provide_session(etl_config) as session:
result = session.select_one("SELECT open_prompt(?)", "Can you write a haiku about DuckDB?", schema_type=ChatMessage)
result = session.select_one(
"SELECT open_prompt(?)",
"Can you write a haiku about DuckDB?",
schema_type=ChatMessage
)
print(result) # result is a ChatMessage pydantic model
```

Expand Down
6 changes: 5 additions & 1 deletion docs/PYPI_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ etl_config = sql.add_config(
)
)
with sql.provide_session(etl_config) as session:
result = session.select_one("SELECT open_prompt(?)", "Can you write a haiku about DuckDB?", schema_type=ChatMessage)
result = session.select_one(
"SELECT open_prompt(?)",
"Can you write a haiku about DuckDB?",
schema_type=ChatMessage
)
print(result) # result is a ChatMessage pydantic model
```

Expand Down
3 changes: 2 additions & 1 deletion sqlspec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlspec import adapters, base, exceptions, extensions, filters, typing, utils
from sqlspec import adapters, base, exceptions, extensions, filters, mixins, typing, utils
from sqlspec.__metadata__ import __version__
from sqlspec.base import SQLSpec

Expand All @@ -10,6 +10,7 @@
"exceptions",
"extensions",
"filters",
"mixins",
"typing",
"utils",
)
11 changes: 8 additions & 3 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

from adbc_driver_manager.dbapi import Connection, Cursor

from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T
from sqlspec.base import SyncDriverAdapterProtocol
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
from sqlspec.mixins import SQLTranslatorMixin, SyncArrowBulkOperationsMixin
from sqlspec.statement import SQLStatement
from sqlspec.typing import ArrowTable, StatementParameterType

if TYPE_CHECKING:
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T

__all__ = ("AdbcDriver",)

Expand All @@ -33,7 +34,11 @@
)


class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]):
class AdbcDriver(
SyncArrowBulkOperationsMixin["Connection"],
SQLTranslatorMixin["Connection"],
SyncDriverAdapterProtocol["Connection"],
):
"""ADBC Sync Driver Adapter."""

connection: Connection
Expand Down
10 changes: 7 additions & 3 deletions sqlspec/adapters/aiosqlite/driver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload

from sqlspec.base import AsyncDriverAdapterProtocol, T
from sqlspec.base import AsyncDriverAdapterProtocol
from sqlspec.mixins import SQLTranslatorMixin

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Sequence

from aiosqlite import Connection, Cursor

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

__all__ = ("AiosqliteDriver",)


class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]):
class AiosqliteDriver(
SQLTranslatorMixin["Connection"],
AsyncDriverAdapterProtocol["Connection"],
):
"""SQLite Async Driver Adapter."""

connection: "Connection"
Expand Down
10 changes: 7 additions & 3 deletions sqlspec/adapters/asyncmy/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload

from sqlspec.base import AsyncDriverAdapterProtocol, T
from sqlspec.base import AsyncDriverAdapterProtocol
from sqlspec.mixins import SQLTranslatorMixin

if TYPE_CHECKING:
from asyncmy import Connection
from asyncmy.cursors import Cursor

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

__all__ = ("AsyncmyDriver",)


class AsyncmyDriver(AsyncDriverAdapterProtocol["Connection"]):
class AsyncmyDriver(
SQLTranslatorMixin["Connection"],
AsyncDriverAdapterProtocol["Connection"],
):
"""Asyncmy MySQL/MariaDB Driver Adapter."""

connection: "Connection"
Expand Down
10 changes: 7 additions & 3 deletions sqlspec/adapters/asyncpg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from asyncpg import Connection
from typing_extensions import TypeAlias

from sqlspec.base import AsyncDriverAdapterProtocol, T
from sqlspec.base import AsyncDriverAdapterProtocol
from sqlspec.exceptions import SQLParsingError
from sqlspec.mixins import SQLTranslatorMixin
from sqlspec.statement import PARAM_REGEX, SQLStatement

if TYPE_CHECKING:
Expand All @@ -15,7 +16,7 @@
from asyncpg.connection import Connection
from asyncpg.pool import PoolConnectionProxy

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

__all__ = ("AsyncpgConnection", "AsyncpgDriver")

Expand All @@ -35,7 +36,10 @@
AsyncpgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument]


class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]):
class AsyncpgDriver(
SQLTranslatorMixin["AsyncpgConnection"],
AsyncDriverAdapterProtocol["AsyncpgConnection"],
):
"""AsyncPG Postgres Driver Adapter."""

connection: "AsyncpgConnection"
Expand Down
12 changes: 9 additions & 3 deletions sqlspec/adapters/duckdb/driver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload

from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T
from sqlspec.base import SyncDriverAdapterProtocol
from sqlspec.mixins import SQLTranslatorMixin, SyncArrowBulkOperationsMixin
from sqlspec.typing import ArrowTable, StatementParameterType

if TYPE_CHECKING:
from collections.abc import Generator, Sequence

from duckdb import DuckDBPyConnection

from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T

__all__ = ("DuckDBDriver",)


class DuckDBDriver(SyncArrowBulkOperationsMixin["DuckDBPyConnection"], SyncDriverAdapterProtocol["DuckDBPyConnection"]):
class DuckDBDriver(
SyncArrowBulkOperationsMixin["DuckDBPyConnection"],
SQLTranslatorMixin["DuckDBPyConnection"],
SyncDriverAdapterProtocol["DuckDBPyConnection"],
):
"""DuckDB Sync Driver Adapter."""

connection: "DuckDBPyConnection"
Expand Down
21 changes: 11 additions & 10 deletions sqlspec/adapters/oracledb/driver.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload

from sqlspec.base import (
AsyncArrowBulkOperationsMixin,
AsyncDriverAdapterProtocol,
SyncArrowBulkOperationsMixin,
SyncDriverAdapterProtocol,
T,
)
from sqlspec.typing import ArrowTable, StatementParameterType
from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
from sqlspec.mixins import AsyncArrowBulkOperationsMixin, SQLTranslatorMixin, SyncArrowBulkOperationsMixin
from sqlspec.typing import ArrowTable, StatementParameterType, T

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Sequence
Expand All @@ -21,7 +16,11 @@
__all__ = ("OracleAsyncDriver", "OracleSyncDriver")


class OracleSyncDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]):
class OracleSyncDriver(
SyncArrowBulkOperationsMixin["Connection"],
SQLTranslatorMixin["Connection"],
SyncDriverAdapterProtocol["Connection"],
):
"""Oracle Sync Driver Adapter."""

connection: "Connection"
Expand Down Expand Up @@ -433,7 +432,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType]


class OracleAsyncDriver(
AsyncArrowBulkOperationsMixin["AsyncConnection"], AsyncDriverAdapterProtocol["AsyncConnection"]
AsyncArrowBulkOperationsMixin["AsyncConnection"],
SQLTranslatorMixin["AsyncConnection"],
AsyncDriverAdapterProtocol["AsyncConnection"],
):
"""Oracle Async Driver Adapter."""

Expand Down
10 changes: 7 additions & 3 deletions sqlspec/adapters/psqlpy/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@

from psqlpy.exceptions import RustPSQLDriverPyBaseError

from sqlspec.base import AsyncDriverAdapterProtocol, T
from sqlspec.base import AsyncDriverAdapterProtocol
from sqlspec.exceptions import SQLParsingError
from sqlspec.mixins import SQLTranslatorMixin
from sqlspec.statement import PARAM_REGEX, SQLStatement

if TYPE_CHECKING:
from collections.abc import Sequence

from psqlpy import Connection, QueryResult

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

__all__ = ("PsqlpyDriver",)

Expand All @@ -33,7 +34,10 @@
logger = logging.getLogger("sqlspec")


class PsqlpyDriver(AsyncDriverAdapterProtocol["Connection"]):
class PsqlpyDriver(
SQLTranslatorMixin["Connection"],
AsyncDriverAdapterProtocol["Connection"],
):
"""Psqlpy Postgres Driver Adapter."""

connection: "Connection"
Expand Down
19 changes: 14 additions & 5 deletions sqlspec/adapters/psycopg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@

from psycopg.rows import dict_row

from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T
from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
from sqlspec.exceptions import SQLParsingError
from sqlspec.mixins import SQLTranslatorMixin
from sqlspec.statement import PARAM_REGEX, SQLStatement

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Sequence

from psycopg import AsyncConnection, Connection

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

logger = logging.getLogger("sqlspec")

__all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver")


class PsycopgParameterParser:
class PsycopgDriverBase:
dialect: str

def _process_sql_params(
Expand Down Expand Up @@ -76,7 +77,11 @@ def _process_sql_params(
return processed_sql, processed_params


class PsycopgSyncDriver(PsycopgParameterParser, SyncDriverAdapterProtocol["Connection"]):
class PsycopgSyncDriver(
PsycopgDriverBase,
SQLTranslatorMixin["Connection"],
SyncDriverAdapterProtocol["Connection"],
):
"""Psycopg Sync Driver Adapter."""

connection: "Connection"
Expand Down Expand Up @@ -482,7 +487,11 @@ def execute_script(
return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE"


class PsycopgAsyncDriver(PsycopgParameterParser, AsyncDriverAdapterProtocol["AsyncConnection"]):
class PsycopgAsyncDriver(
PsycopgDriverBase,
SQLTranslatorMixin["AsyncConnection"],
AsyncDriverAdapterProtocol["AsyncConnection"],
):
"""Psycopg Async Driver Adapter."""

connection: "AsyncConnection"
Expand Down
10 changes: 7 additions & 3 deletions sqlspec/adapters/sqlite/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
from sqlite3 import Connection, Cursor
from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload

from sqlspec.base import SyncDriverAdapterProtocol, T
from sqlspec.base import SyncDriverAdapterProtocol
from sqlspec.mixins import SQLTranslatorMixin

if TYPE_CHECKING:
from collections.abc import Generator, Sequence

from sqlspec.typing import ModelDTOT, StatementParameterType
from sqlspec.typing import ModelDTOT, StatementParameterType, T

__all__ = ("SqliteDriver",)


class SqliteDriver(SyncDriverAdapterProtocol["Connection"]):
class SqliteDriver(
SQLTranslatorMixin["Connection"],
SyncDriverAdapterProtocol["Connection"],
):
"""SQLite Sync Driver Adapter."""

connection: "Connection"
Expand Down
Loading