diff --git a/README.md b/README.md index 7cc368243..41d11de0e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,144 @@ SQLSpec is an experimental Python library designed to streamline and modernize y SQLSpec is a work in progress. While it offers a solid foundation for modern SQL interactions, it does not yet include every feature you might find in a mature ORM or database toolkit. The focus is on building a robust, flexible core that can be extended over time. +## Examples + +We've talked about what SQLSpec is not, so let's look at what it can do. + +### Basic Example + +### Multiple Database Engines + +### Examples, Bells and Whistles + +Let's take a look at a few alternate examples that leverage DuckDB and Litestar. + +#### DuckDB LLM + +This is a quick implementation using some of the built in Secret and Extension management features of SQLSpec's DuckDB integration. + +It allows you to communicate with any compatible OpenAPI conversations endpoint (such as Ollama). This examples: + +- auto installs the `open_prompt` DuckDB extensions +- automatically creates the correct `open_prompt` comptaible secret required to use the extension + +```py +# /// script +# dependencies = [ +# "sqlspec[duckdb,performance]", +# ] +# /// +import os + +from sqlspec import SQLSpec +from sqlspec.adapters.duckdb import DuckDBConfig + +sql = SQLSpec() +etl_config = sql.add_config( + DuckDBConfig( + extensions=[{"name": "open_prompt"}], + secrets=[ + { + "secret_type": "open_prompt", + "name": "open_prompt", + "value": { + "api_url": "http://127.0.0.1:11434/v1/chat/completions", + "model_name": "gemma3:1b", + "api_timeout": "120", + }, + } + ], + ) +) +with sql.provide_session(etl_config) as session: + result = session.select_one("SELECT generate_embedding('example text')") + print(result) +``` + +#### DuckDB Gemini Embeddings + +In this example, we are again using DuckDB. However, we are going to use the built in to call the Google Gemini embeddings service directly from the database. + +This example will + +- auto installs the `http_client` and `vss` (vector similarity search) DuckDB extensions +- when a connection is created, it ensures that the `generate_embeddings` macro exists in the DuckDB database. +- Execute a simple query to call the Google API + +```py +# /// script +# dependencies = [ +# "sqlspec[duckdb,performance]", +# ] +# /// +import os + +from sqlspec import SQLSpec +from sqlspec.adapters.duckdb import DuckDBConfig + +EMBEDDING_MODEL = "gemini-embedding-exp-03-07" +GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") +API_URL = ( + f"https://generativelanguage.googleapis.com/v1beta/models/{EMBEDDING_MODEL}:embedContent?key=${GOOGLE_API_KEY}" +) + +sql = SQLSpec() +etl_config = sql.add_config( + DuckDBConfig( + extensions=[{"name": "vss"}, {"name": "http_client"}], + on_connection_create=lambda connection: connection.execute(f""" + CREATE IF NOT EXISTS MACRO generate_embedding(q) AS ( + WITH __request AS ( + SELECT http_post( + '{API_URL}', + headers => MAP {{ + 'accept': 'application/json', + }}, + params => MAP {{ + 'model': 'models/{EMBEDDING_MODEL}', + 'parts': [{{ 'text': q }}], + 'taskType': 'SEMANTIC_SIMILARITY' + }} + ) AS response + ) + SELECT * + FROM __request, + ); + """), + ) +) +``` + +#### Basic Litestar Integration + +In this example we are going to demonstrate how to create a basic configuration that integrates into Litestar. + +```py +# /// script +# dependencies = [ +# "sqlspec[aiosqlite]", +# "litestar[standard]", +# ] +# /// + +from aiosqlite import Connection +from litestar import Litestar, get + +from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver +from sqlspec.extensions.litestar import SQLSpec + + +@get("/") +async def simple_sqlite(db_session: AiosqliteDriver) -> dict[str, str]: + return await db_session.select_one("SELECT 'Hello, world!' AS greeting") + + +sqlspec = SQLSpec(config=DatabaseConfig( + config=[AiosqliteConfig(), commit_mode="autocommit")], +) +app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec]) +``` + ## Inspiration and Future Direction SQLSpec originally drew inspiration from features found in the `aiosql` library. This is a great library for working with and executed SQL stored in files. It's unclear how much of an overlap there will be between the two libraries, but it's possible that some features will be contributed back to `aiosql` where appropriate. diff --git a/docs/PYPI_README.md b/docs/PYPI_README.md index 7cc368243..41d11de0e 100644 --- a/docs/PYPI_README.md +++ b/docs/PYPI_README.md @@ -22,6 +22,144 @@ SQLSpec is an experimental Python library designed to streamline and modernize y SQLSpec is a work in progress. While it offers a solid foundation for modern SQL interactions, it does not yet include every feature you might find in a mature ORM or database toolkit. The focus is on building a robust, flexible core that can be extended over time. +## Examples + +We've talked about what SQLSpec is not, so let's look at what it can do. + +### Basic Example + +### Multiple Database Engines + +### Examples, Bells and Whistles + +Let's take a look at a few alternate examples that leverage DuckDB and Litestar. + +#### DuckDB LLM + +This is a quick implementation using some of the built in Secret and Extension management features of SQLSpec's DuckDB integration. + +It allows you to communicate with any compatible OpenAPI conversations endpoint (such as Ollama). This examples: + +- auto installs the `open_prompt` DuckDB extensions +- automatically creates the correct `open_prompt` comptaible secret required to use the extension + +```py +# /// script +# dependencies = [ +# "sqlspec[duckdb,performance]", +# ] +# /// +import os + +from sqlspec import SQLSpec +from sqlspec.adapters.duckdb import DuckDBConfig + +sql = SQLSpec() +etl_config = sql.add_config( + DuckDBConfig( + extensions=[{"name": "open_prompt"}], + secrets=[ + { + "secret_type": "open_prompt", + "name": "open_prompt", + "value": { + "api_url": "http://127.0.0.1:11434/v1/chat/completions", + "model_name": "gemma3:1b", + "api_timeout": "120", + }, + } + ], + ) +) +with sql.provide_session(etl_config) as session: + result = session.select_one("SELECT generate_embedding('example text')") + print(result) +``` + +#### DuckDB Gemini Embeddings + +In this example, we are again using DuckDB. However, we are going to use the built in to call the Google Gemini embeddings service directly from the database. + +This example will + +- auto installs the `http_client` and `vss` (vector similarity search) DuckDB extensions +- when a connection is created, it ensures that the `generate_embeddings` macro exists in the DuckDB database. +- Execute a simple query to call the Google API + +```py +# /// script +# dependencies = [ +# "sqlspec[duckdb,performance]", +# ] +# /// +import os + +from sqlspec import SQLSpec +from sqlspec.adapters.duckdb import DuckDBConfig + +EMBEDDING_MODEL = "gemini-embedding-exp-03-07" +GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") +API_URL = ( + f"https://generativelanguage.googleapis.com/v1beta/models/{EMBEDDING_MODEL}:embedContent?key=${GOOGLE_API_KEY}" +) + +sql = SQLSpec() +etl_config = sql.add_config( + DuckDBConfig( + extensions=[{"name": "vss"}, {"name": "http_client"}], + on_connection_create=lambda connection: connection.execute(f""" + CREATE IF NOT EXISTS MACRO generate_embedding(q) AS ( + WITH __request AS ( + SELECT http_post( + '{API_URL}', + headers => MAP {{ + 'accept': 'application/json', + }}, + params => MAP {{ + 'model': 'models/{EMBEDDING_MODEL}', + 'parts': [{{ 'text': q }}], + 'taskType': 'SEMANTIC_SIMILARITY' + }} + ) AS response + ) + SELECT * + FROM __request, + ); + """), + ) +) +``` + +#### Basic Litestar Integration + +In this example we are going to demonstrate how to create a basic configuration that integrates into Litestar. + +```py +# /// script +# dependencies = [ +# "sqlspec[aiosqlite]", +# "litestar[standard]", +# ] +# /// + +from aiosqlite import Connection +from litestar import Litestar, get + +from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver +from sqlspec.extensions.litestar import SQLSpec + + +@get("/") +async def simple_sqlite(db_session: AiosqliteDriver) -> dict[str, str]: + return await db_session.select_one("SELECT 'Hello, world!' AS greeting") + + +sqlspec = SQLSpec(config=DatabaseConfig( + config=[AiosqliteConfig(), commit_mode="autocommit")], +) +app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec]) +``` + ## Inspiration and Future Direction SQLSpec originally drew inspiration from features found in the `aiosql` library. This is a great library for working with and executed SQL stored in files. It's unclear how much of an overlap there will be between the two libraries, but it's possible that some features will be contributed back to `aiosql` where appropriate. diff --git a/docs/examples/litestar_duckllm.py b/docs/examples/litestar_duckllm.py index 437362eb6..e26b992eb 100644 --- a/docs/examples/litestar_duckllm.py +++ b/docs/examples/litestar_duckllm.py @@ -13,11 +13,11 @@ # ] # /// -from duckdb import DuckDBPyConnection from litestar import Litestar, post from msgspec import Struct -from sqlspec.adapters.duckdb import DuckDB +from sqlspec.adapters.duckdb import DuckDBConfig +from sqlspec.adapters.duckdb.driver import DuckDBDriver from sqlspec.extensions.litestar import SQLSpec @@ -26,13 +26,12 @@ class ChatMessage(Struct): @post("/chat", sync_to_thread=True) -def duckllm_chat(db_connection: DuckDBPyConnection, data: ChatMessage) -> ChatMessage: - result = db_connection.execute("SELECT open_prompt(?)", (data.message,)).fetchall() - return ChatMessage(message=result[0][0]) +def duckllm_chat(db_session: DuckDBDriver, data: ChatMessage) -> ChatMessage: + return db_session.select_one("SELECT open_prompt(?)", data.message, schema_type=ChatMessage) sqlspec = SQLSpec( - config=DuckDB( + config=DuckDBConfig( extensions=[{"name": "open_prompt"}], secrets=[ { diff --git a/docs/examples/litestar_multi_db.py b/docs/examples/litestar_multi_db.py index b219903d5..f9469320c 100644 --- a/docs/examples/litestar_multi_db.py +++ b/docs/examples/litestar_multi_db.py @@ -1,34 +1,56 @@ -from aiosqlite import Connection -from duckdb import DuckDBPyConnection +"""Litestar Multi DB + +This example demonstrates how to use multiple databases in a Litestar application. + +The example uses the `SQLSpec` extension to create a connection to a SQLite (via `aiosqlite`) and DuckDB database. + +The DuckDB database also demonstrates how to use the plugin loader and `secrets` configuration manager built into SQLSpec. +""" +# /// script +# dependencies = [ +# "sqlspec[aiosqlite,duckdb]", +# "litestar[standard]", +# ] +# /// + from litestar import Litestar, get -from sqlspec.adapters.aiosqlite import Aiosqlite -from sqlspec.adapters.duckdb import DuckDB +from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver +from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec @get("/test", sync_to_thread=True) -def simple_select(etl_connection: DuckDBPyConnection) -> dict[str, str]: - result = etl_connection.execute("SELECT 'Hello, world!' AS greeting").fetchall() - return {"greeting": result[0][0]} +def simple_select(etl_session: DuckDBDriver) -> dict[str, str]: + result = etl_session.select_one("SELECT 'Hello, ETL world!' AS greeting") + return {"greeting": result["greeting"]} @get("/") -async def simple_sqlite(db_connection: Connection) -> dict[str, str]: - result = await db_connection.execute_fetchall("SELECT 'Hello, world!' AS greeting") - return {"greeting": result[0][0]} # type: ignore +async def simple_sqlite(db_session: AiosqliteDriver) -> dict[str, str]: + return await db_session.select_one("SELECT 'Hello, world!' AS greeting") sqlspec = SQLSpec( config=[ - DatabaseConfig(config=Aiosqlite(), commit_mode="autocommit"), + DatabaseConfig(config=AiosqliteConfig(), commit_mode="autocommit"), DatabaseConfig( - config=DuckDB( + config=DuckDBConfig( extensions=[{"name": "vss", "force_install": True}], secrets=[{"secret_type": "s3", "name": "s3_secret", "value": {"key_id": "abcd"}}], ), connection_key="etl_connection", + session_key="etl_session", ), ], ) app = Litestar(route_handlers=[simple_sqlite, simple_select], plugins=[sqlspec]) + +if __name__ == "__main__": + import os + + from litestar.cli import litestar_group + + os.environ["LITESTAR_APP"] = "docs.examples.litestar_multi_db:app" + + litestar_group() diff --git a/docs/examples/litestar_single_db.py b/docs/examples/litestar_single_db.py index 69260b723..b58030773 100644 --- a/docs/examples/litestar_single_db.py +++ b/docs/examples/litestar_single_db.py @@ -1,7 +1,21 @@ +"""Litestar Single DB + +This example demonstrates how to use a single database in a Litestar application. + +This examples hows how to get the raw connection object from the SQLSpec plugin. +""" + +# /// script +# dependencies = [ +# "sqlspec[aiosqlite]", +# "litestar[standard]", +# ] +# /// + from aiosqlite import Connection from litestar import Litestar, get -from sqlspec.adapters.aiosqlite import Aiosqlite +from sqlspec.adapters.aiosqlite import AiosqliteConfig from sqlspec.extensions.litestar import SQLSpec @@ -16,5 +30,5 @@ async def simple_sqlite(db_connection: Connection) -> dict[str, str]: return {"greeting": result[0][0]} # type: ignore -sqlspec = SQLSpec(config=Aiosqlite()) +sqlspec = SQLSpec(config=AiosqliteConfig()) app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec]) diff --git a/docs/examples/litestar_gemini.py b/docs/examples/standalone_duckdb.py similarity index 71% rename from docs/examples/litestar_gemini.py rename to docs/examples/standalone_duckdb.py index 24064f7ce..a90275fe5 100644 --- a/docs/examples/litestar_gemini.py +++ b/docs/examples/standalone_duckdb.py @@ -1,22 +1,18 @@ -"""Litestar DuckLLM +"""Generating embeddings with Gemini -This example demonstrates how to use the Litestar framework with the DuckLLM extension. - -The example uses the `SQLSpec` extension to create a connection to the DuckDB database. -The `DuckDB` adapter is used to create a connection to the database. +This example demonstrates how to generate embeddings with Gemini using only DuckDB and the HTTP client extension. """ # /// script # dependencies = [ # "sqlspec[duckdb,performance]", -# "litestar[standard]", # ] # /// import os from sqlspec import SQLSpec -from sqlspec.adapters.duckdb import DuckDB +from sqlspec.adapters.duckdb import DuckDBConfig EMBEDDING_MODEL = "gemini-embedding-exp-03-07" GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") @@ -26,7 +22,7 @@ sql = SQLSpec() etl_config = sql.add_config( - DuckDB( + DuckDBConfig( extensions=[{"name": "vss"}, {"name": "http_client"}], on_connection_create=lambda connection: connection.execute(f""" CREATE IF NOT EXISTS MACRO generate_embedding(q) AS ( @@ -52,6 +48,6 @@ if __name__ == "__main__": - with sql.get_connection(etl_config) as connection: - result = connection.execute("SELECT generate_embedding('example text')") - print(result.fetchall()) + with sql.provide_session(etl_config) as session: + result = session.select_one("SELECT generate_embedding('example text')") + print(result) diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index fc9e5c4d9..6b3f7ec2b 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -1,9 +1,9 @@ import contextlib import logging import re -from collections.abc import Generator +from collections.abc import Generator, Sequence from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, overload from adbc_driver_manager.dbapi import Connection, Cursor @@ -160,6 +160,28 @@ def _process_sql_params( stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) return stmt.process() + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... def select( self, sql: str, @@ -169,7 +191,7 @@ def select( connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -189,6 +211,28 @@ def select( return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results] # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] return [dict(zip(column_names, row)) for row in results] # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def select_one( self, sql: str, @@ -215,6 +259,28 @@ def select_one( return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] return schema_type(**dict(zip(column_names, result))) # type: ignore[return-value] + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... def select_one_or_none( self, sql: str, @@ -242,6 +308,28 @@ def select_one_or_none( return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] return schema_type(**dict(zip(column_names, result))) # type: ignore[return-value] + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... def select_value( self, sql: str, @@ -267,6 +355,28 @@ def select_value( return result[0] # pyright: ignore[reportUnknownVariableType] return schema_type(result[0]) # type: ignore[call-arg] + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... def select_value_or_none( self, sql: str, @@ -314,6 +424,28 @@ def insert_update_delete( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cursor.rowcount if hasattr(cursor, "rowcount") else -1 + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index b61ee3b2e..fc96a99fc 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -1,10 +1,10 @@ from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import AsyncDriverAdapterProtocol, T if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Sequence from aiosqlite import Connection, Cursor @@ -34,6 +34,29 @@ async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor finally: await cursor.close() + # --- Public API Methods --- # + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -43,7 +66,7 @@ async def select( connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -61,6 +84,28 @@ async def select( return [dict(zip(column_names, row)) for row in results] # pyright: ignore[reportUnknownArgumentType] return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results] # pyright: ignore[reportUnknownArgumentType] + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -87,6 +132,28 @@ async def select_one( return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -114,6 +181,28 @@ async def select_one_or_none( return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -139,6 +228,28 @@ async def select_value( return result[0] return schema_type(result[0]) # type: ignore[call-arg] + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... async def select_value_or_none( self, sql: str, @@ -186,6 +297,28 @@ async def insert_update_delete( await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cursor.rowcount if hasattr(cursor, "rowcount") else -1 # pyright: ignore[reportUnknownVariableType, reportGeneralTypeIssues] + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 2ffa02be2..19552c6e4 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -1,7 +1,7 @@ # type: ignore -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import AsyncDriverAdapterProtocol, T @@ -36,6 +36,29 @@ async def _with_cursor(connection: "Connection") -> AsyncGenerator["Cursor", Non finally: await cursor.close() + # --- Public API Methods --- # + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -45,7 +68,7 @@ async def select( connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -63,6 +86,28 @@ async def select( return [dict(zip(column_names, row)) for row in results] return [schema_type(**dict(zip(column_names, row))) for row in results] + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -89,6 +134,28 @@ async def select_one( return dict(zip(column_names, result)) return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -116,6 +183,28 @@ async def select_one_or_none( return dict(zip(column_names, result)) return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -144,6 +233,28 @@ async def select_value( return schema_type(value) # type: ignore[call-arg] return value + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... async def select_value_or_none( self, sql: str, @@ -195,6 +306,28 @@ async def insert_update_delete( await cursor.execute(sql, parameters) return cursor.rowcount + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 28fa4b63d..e5b19564d 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -1,6 +1,6 @@ import logging import re -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from asyncpg import Connection from typing_extensions import TypeAlias @@ -10,6 +10,8 @@ from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: + from collections.abc import Sequence + from asyncpg.connection import Connection from asyncpg.pool import PoolConnectionProxy @@ -196,6 +198,28 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915 # No parameters provided and none found in SQL, return original SQL from SQLStatement and empty tuple return sql, () # asyncpg expects a sequence, even if empty + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -205,7 +229,7 @@ async def select( connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Args: @@ -229,6 +253,28 @@ async def select( return [dict(row.items()) for row in results] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] return [cast("ModelDTOT", schema_type(**dict(row.items()))) for row in results] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -262,6 +308,28 @@ async def select_one( return dict(result.items()) # type: ignore[attr-defined] return cast("ModelDTOT", schema_type(**dict(result.items()))) # type: ignore[attr-defined] + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -295,6 +363,28 @@ async def select_one_or_none( return dict(result.items()) return cast("ModelDTOT", schema_type(**dict(result.items()))) + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -326,6 +416,28 @@ async def select_value( return result return schema_type(result) # type: ignore[call-arg] + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... async def select_value_or_none( self, sql: str, @@ -381,6 +493,28 @@ async def insert_update_delete( except (ValueError, IndexError, AttributeError): return -1 # Fallback if we can't parse the status + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncpgConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 18a4abec0..dd9c6eb1e 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,10 +1,10 @@ from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from duckdb import DuckDBPyConnection @@ -42,17 +42,38 @@ def _with_cursor(self, connection: "DuckDBPyConnection") -> "Generator[DuckDBPyC yield connection # --- Public API Methods --- # - + @overload def select( self, sql: str, - parameters: Optional["StatementParameterType"] = None, + parameters: "Optional[StatementParameterType]" = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: @@ -67,6 +88,28 @@ def select( return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results] # pyright: ignore[reportUnknownArgumentType] return [dict(zip(column_names, row)) for row in results] # pyright: ignore[reportUnknownArgumentType] + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def select_one( self, sql: str, @@ -90,6 +133,28 @@ def select_one( # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... def select_one_or_none( self, sql: str, @@ -113,6 +178,28 @@ def select_one_or_none( return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... def select_value( self, sql: str, @@ -133,6 +220,28 @@ def select_value( return result[0] # pyright: ignore return schema_type(result[0]) # type: ignore[call-arg] + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... def select_value_or_none( self, sql: str, @@ -169,23 +278,44 @@ def insert_update_delete( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] + @overload def insert_update_delete_returning( self, sql: str, - parameters: Optional["StatementParameterType"] = None, + parameters: "Optional[StatementParameterType]" = None, /, *, - connection: Optional["DuckDBPyConnection"] = None, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": + ) -> "Union[ModelDTOT, dict[str, Any]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - if not result: - return None # pyright: ignore[reportUnknownArgumentType] + result = self.check_not_found(result) # pyright: ignore column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] if schema_type is not None: return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 40e9cabbd..32cdb2e9e 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager, contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import ( AsyncArrowBulkOperationsMixin, @@ -11,7 +11,7 @@ from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator + from collections.abc import AsyncGenerator, Generator, Sequence from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor @@ -39,6 +39,29 @@ def _with_cursor(connection: "Connection") -> "Generator[Cursor, None, None]": finally: cursor.close() + # --- Public API Methods --- # + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... def select( self, sql: str, @@ -48,7 +71,7 @@ def select( connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Args: @@ -76,6 +99,28 @@ def select( return [dict(zip(column_names, row)) for row in results] # pyright: ignore + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def select_one( self, sql: str, @@ -114,6 +159,28 @@ def select_one( # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... def select_one_or_none( self, sql: str, @@ -147,6 +214,28 @@ def select_one_or_none( # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... def select_value( self, sql: str, @@ -174,6 +263,28 @@ def select_value( return result[0] # pyright: ignore[reportUnknownArgumentType] return schema_type(result[0]) # type: ignore[call-arg] + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... def select_value_or_none( self, sql: str, @@ -224,6 +335,28 @@ def insert_update_delete( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cursor.rowcount # pyright: ignore[reportUnknownMemberType] + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def insert_update_delete_returning( self, sql: str, @@ -319,6 +452,29 @@ async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[AsyncCu finally: cursor.close() + # --- Public API Methods --- # + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -328,7 +484,7 @@ async def select( connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -350,6 +506,28 @@ async def select( return [dict(zip(column_names, row)) for row in results] # pyright: ignore + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -380,6 +558,28 @@ async def select_one( # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -413,6 +613,28 @@ async def select_one_or_none( # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -440,6 +662,28 @@ async def select_value( return result[0] # pyright: ignore[reportUnknownArgumentType] return schema_type(result[0]) # type: ignore[call-arg] + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... async def select_value_or_none( self, sql: str, @@ -490,6 +734,28 @@ async def insert_update_delete( await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cursor.rowcount # pyright: ignore[reportUnknownMemberType] + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index b88c3d3bf..193089714 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -207,25 +207,17 @@ async def _create() -> "ConnectionPool": def create_connection(self) -> "Awaitable[Connection]": """Create and return a new, standalone psqlpy connection using the configured parameters. - Note: This method is not supported by the psqlpy adapter as connection - creation is primarily handled via the ConnectionPool. - Use `provide_connection` or `provide_session` for pooled connections. - Returns: An awaitable that resolves to a new Connection instance. - - Raises: - NotImplementedError: This method is not implemented for psqlpy. """ async def _create() -> "Connection": - # psqlpy does not seem to offer a public API for creating - # standalone async connections easily outside the pool context. - msg = ( - "Creating standalone connections is not directly supported by the psqlpy adapter. " - "Please use the pool via `provide_connection` or `provide_session`." - ) - raise NotImplementedError(msg) + try: + async with self.provide_connection() as conn: + return conn + except Exception as e: + msg = f"Could not configure the psqlpy connection. Error: {e!s}" + raise ImproperConfigurationError(msg) from e return _create() diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index a8272e486..0b3f82a2c 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -3,7 +3,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from psqlpy.exceptions import RustPSQLDriverPyBaseError @@ -12,6 +12,8 @@ from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: + from collections.abc import Sequence + from psqlpy import Connection, QueryResult from sqlspec.typing import ModelDTOT, StatementParameterType @@ -52,6 +54,17 @@ def _process_sql_params( psqlpy uses $1, $2 style parameters natively. This method converts '?' (tuple/list) and ':name' (dict) styles to $n. It relies on SQLStatement for initial parameter validation and merging. + + Args: + sql: The SQL to process. + parameters: The parameters to process. + kwargs: Additional keyword arguments. + + Raises: + SQLParsingError: If the SQL is invalid. + + Returns: + A tuple of the processed SQL and parameters. """ stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) sql, parameters = stmt.process() @@ -154,6 +167,29 @@ def _process_sql_params( return sql, () + # --- Public API Methods --- # + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -163,7 +199,7 @@ async def select( connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) parameters = parameters or [] # psqlpy expects a list/tuple @@ -171,9 +207,31 @@ async def select( results: QueryResult = await connection.fetch(sql, parameters=parameters) if schema_type is None: - return cast("list[dict[str, Any]]", results.result()) # type: ignore[return-value] + return cast("list[dict[str, Any]]", results.result()) return results.as_class(as_class=schema_type) + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -195,6 +253,28 @@ async def select_one( return cast("dict[str, Any]", result[0]) # type: ignore[index] return result.as_class(as_class=schema_type)[0] + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -220,6 +300,28 @@ async def select_one_or_none( return None return cast("ModelDTOT", result[0]) # type: ignore[index] + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -240,6 +342,28 @@ async def select_value( return value return schema_type(value) # type: ignore[call-arg] + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... async def select_value_or_none( self, sql: str, @@ -282,6 +406,28 @@ async def insert_update_delete( # if no error was raised return 1 + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 345593c02..07808cbc3 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager, contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from psycopg.rows import dict_row @@ -9,7 +9,7 @@ from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator + from collections.abc import AsyncGenerator, Generator, Sequence from psycopg import AsyncConnection, Connection @@ -20,7 +20,63 @@ __all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver") -class PsycopgSyncDriver(SyncDriverAdapterProtocol["Connection"]): +class PsycopgParameterParser: + dialect: str + + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process SQL and parameters, converting :name -> %(name)s if needed.""" + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + processed_sql, processed_params = stmt.process() + + if isinstance(processed_params, dict): + parameter_dict = processed_params + processed_sql_parts: list[str] = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(processed_sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {processed_sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) + last_end = end + + processed_sql_parts.append(processed_sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + if not found_params_regex and parameter_dict: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameter_dict.keys()), + processed_sql, + ) + return processed_sql, parameter_dict + + return final_sql, parameter_dict + + return processed_sql, processed_params + + +class PsycopgSyncDriver(PsycopgParameterParser, SyncDriverAdapterProtocol["Connection"]): """Psycopg Sync Driver Adapter.""" connection: "Connection" @@ -90,6 +146,29 @@ def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": finally: cursor.close() + # --- Public API Methods --- # + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... def select( self, sql: str, @@ -99,7 +178,7 @@ def select( schema_type: "Optional[type[ModelDTOT]]" = None, connection: "Optional[Connection]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -117,6 +196,28 @@ def select( return [cast("ModelDTOT", schema_type(**row)) for row in results] # pyright: ignore[reportUnknownArgumentType] return [cast("dict[str,Any]", row) for row in results] # pyright: ignore[reportUnknownArgumentType] + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def select_one( self, sql: str, @@ -142,6 +243,28 @@ def select_one( return cast("ModelDTOT", schema_type(**cast("dict[str,Any]", row))) return cast("dict[str,Any]", row) + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... def select_one_or_none( self, sql: str, @@ -168,6 +291,28 @@ def select_one_or_none( return cast("ModelDTOT", schema_type(**cast("dict[str,Any]", row))) return cast("dict[str,Any]", row) + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... def select_value( self, sql: str, @@ -195,6 +340,28 @@ def select_value( return schema_type(val) # type: ignore[call-arg] return val + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... def select_value_or_none( self, sql: str, @@ -244,6 +411,28 @@ def insert_update_delete( cursor.execute(sql, parameters) return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def insert_update_delete_returning( self, sql: str, @@ -293,7 +482,7 @@ def execute_script( return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" -class PsycopgAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): +class PsycopgAsyncDriver(PsycopgParameterParser, AsyncDriverAdapterProtocol["AsyncConnection"]): """Psycopg Async Driver Adapter.""" connection: "AsyncConnection" @@ -302,58 +491,6 @@ class PsycopgAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): def __init__(self, connection: "AsyncConnection") -> None: self.connection = connection - def _process_sql_params( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - **kwargs: Any, - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL and parameters, converting :name -> %(name)s if needed.""" - stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) - processed_sql, processed_params = stmt.process() - - if isinstance(processed_params, dict): - parameter_dict = processed_params - processed_sql_parts: list[str] = [] - last_end = 0 - found_params_regex: list[str] = [] - - for match in PARAM_REGEX.finditer(processed_sql): - if match.group("dquote") or match.group("squote") or match.group("comment"): - continue - - if match.group("var_name"): - var_name = match.group("var_name") - found_params_regex.append(var_name) - start = match.start("var_name") - 1 - end = match.end("var_name") - - if var_name not in parameter_dict: - msg = ( - f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " - f"Processed SQL: {processed_sql}" - ) - raise SQLParsingError(msg) - - processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) - last_end = end - - processed_sql_parts.append(processed_sql[last_end:]) - final_sql = "".join(processed_sql_parts) - - if not found_params_regex and parameter_dict: - logger.warning( - "Dict params provided (%s), but no :name placeholders found. SQL: %s", - list(parameter_dict.keys()), - processed_sql, - ) - return processed_sql, parameter_dict - - return final_sql, parameter_dict - - return processed_sql, processed_params - @staticmethod @asynccontextmanager async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, None]": @@ -363,6 +500,29 @@ async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, No finally: await cursor.close() + # --- Public API Methods --- # + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... async def select( self, sql: str, @@ -372,7 +532,7 @@ async def select( connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -391,6 +551,28 @@ async def select( return [cast("ModelDTOT", schema_type(**cast("dict[str,Any]", row))) for row in results] # pyright: ignore[reportUnknownArgumentType] return [cast("dict[str,Any]", row) for row in results] # pyright: ignore[reportUnknownArgumentType] + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def select_one( self, sql: str, @@ -417,6 +599,28 @@ async def select_one( return cast("ModelDTOT", schema_type(**cast("dict[str,Any]", row))) return cast("dict[str,Any]", row) + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... async def select_one_or_none( self, sql: str, @@ -444,6 +648,28 @@ async def select_one_or_none( return cast("ModelDTOT", schema_type(**cast("dict[str,Any]", row))) return cast("dict[str,Any]", row) + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... async def select_value( self, sql: str, @@ -527,6 +753,28 @@ async def insert_update_delete( rowcount = -1 return rowcount + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[AsyncConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... async def insert_update_delete_returning( self, sql: str, diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 7c5916abb..a9342ae00 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -1,11 +1,11 @@ from contextlib import contextmanager from sqlite3 import Connection, Cursor -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload from sqlspec.base import SyncDriverAdapterProtocol, T if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from sqlspec.typing import ModelDTOT, StatementParameterType @@ -33,6 +33,29 @@ def _with_cursor(self, connection: "Connection") -> "Generator[Cursor, None, Non finally: cursor.close() + # --- Public API Methods --- # + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + @overload + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... def select( self, sql: str, @@ -42,7 +65,7 @@ def select( connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. Returns: @@ -63,6 +86,28 @@ def select( return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results] # pyright: ignore[reportUnknownArgumentType] return [dict(zip(column_names, row)) for row in results] # pyright: ignore[reportUnknownArgumentType] + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def select_one( self, sql: str, @@ -92,6 +137,28 @@ def select_one( return dict(zip(column_names, result)) return schema_type(**dict(zip(column_names, result))) # type: ignore[return-value] + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... def select_one_or_none( self, sql: str, @@ -122,6 +189,28 @@ def select_one_or_none( return dict(zip(column_names, result)) return schema_type(**dict(zip(column_names, result))) # type: ignore[return-value] + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... def select_value( self, sql: str, @@ -150,6 +239,28 @@ def select_value( return result[0] return schema_type(result[0]) # type: ignore[call-arg] + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + @overload + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... def select_value_or_none( self, sql: str, @@ -203,6 +314,28 @@ def insert_update_delete( cursor.execute(sql, parameters) return cursor.rowcount if hasattr(cursor, "rowcount") else -1 + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + @overload + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[Connection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... def insert_update_delete_returning( self, sql: str, @@ -272,34 +405,3 @@ def execute_script( cursor.execute(sql, parameters) return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - - def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - *, - connection: Optional["Connection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - - with self._with_cursor(connection) as cursor: - if not parameters: - cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] - else: - cursor.execute(sql, parameters) - result = cursor.fetchall() - if len(result) == 0: - return None - column_names = [c[0] for c in cursor.description or []] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) - return dict(zip(column_names, result[0])) diff --git a/sqlspec/base.py b/sqlspec/base.py index b01efee26..5e4c09a96 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -3,7 +3,7 @@ import contextlib import re from abc import ABC, abstractmethod -from collections.abc import Awaitable +from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, @@ -296,6 +296,24 @@ def get_connection( config = self.get_config(name) return config.create_connection() + @overload + def get_session( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + ) -> "DriverT": ... + + @overload + def get_session( + self, + name: Union[ + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + ) -> "Awaitable[DriverT]": ... + def get_session( self, name: Union[ @@ -323,6 +341,28 @@ async def _create_session() -> DriverT: return _create_session() return cast("DriverT", config.driver_type(connection)) # pyright: ignore + @overload + def provide_connection( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "AbstractContextManager[ConnectionT]": ... + + @overload + def provide_connection( + self, + name: Union[ + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "AbstractAsyncContextManager[ConnectionT]": ... + def provide_connection( self, name: Union[ @@ -347,6 +387,28 @@ def provide_connection( config = self.get_config(name) return config.provide_connection(*args, **kwargs) + @overload + def provide_session( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "AbstractContextManager[DriverT]": ... + + @overload + def provide_session( + self, + name: Union[ + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + *args: Any, + **kwargs: Any, + ) -> "AbstractAsyncContextManager[DriverT]": ... + def provide_session( self, name: Union[ @@ -405,6 +467,24 @@ def get_pool( return cast("Union[type[PoolT], Awaitable[type[PoolT]]]", config.create_pool()) return None + @overload + def close_pool( + self, + name: Union[ + "type[NoPoolSyncConfig[ConnectionT, DriverT]]", + "type[SyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + ) -> "None": ... + + @overload + def close_pool( + self, + name: Union[ + "type[NoPoolAsyncConfig[ConnectionT, DriverT]]", + "type[AsyncDatabaseConfig[ConnectionT, PoolT, DriverT]]", + ], + ) -> "Awaitable[None]": ... + def close_pool( self, name: Union[ @@ -513,6 +593,32 @@ class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generi def __init__(self, connection: "ConnectionT", **kwargs: Any) -> None: self.connection = connection + @overload + @abstractmethod + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + + @overload + @abstractmethod + def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... + @abstractmethod def select( self, @@ -523,7 +629,33 @@ def select( connection: "Optional[ConnectionT]" = None, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": ... + + @overload + @abstractmethod + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + + @overload + @abstractmethod + def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... @abstractmethod def select_one( @@ -537,6 +669,32 @@ def select_one( **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... + @overload + @abstractmethod + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + + @overload + @abstractmethod + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... + @abstractmethod def select_one_or_none( self, @@ -549,6 +707,32 @@ def select_one_or_none( **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... + @overload + @abstractmethod + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + + @overload + @abstractmethod + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... + @abstractmethod def select_value( self, @@ -559,7 +743,33 @@ def select_value( connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, **kwargs: Any, - ) -> "Union[Any, T]": ... + ) -> "Union[T, Any]": ... + + @overload + @abstractmethod + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + + @overload + @abstractmethod + def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... @abstractmethod def select_value_or_none( @@ -571,7 +781,7 @@ def select_value_or_none( connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, **kwargs: Any, - ) -> "Optional[Union[Any, T]]": ... + ) -> "Optional[Union[T, Any]]": ... @abstractmethod def insert_update_delete( @@ -584,6 +794,32 @@ def insert_update_delete( **kwargs: Any, ) -> int: ... + @overload + @abstractmethod + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + + @overload + @abstractmethod + def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... + @abstractmethod def insert_update_delete_returning( self, @@ -594,7 +830,7 @@ def insert_update_delete_returning( connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... + ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod def execute_script( @@ -643,6 +879,32 @@ class AsyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Gener def __init__(self, connection: "ConnectionT") -> None: self.connection = connection + @overload + @abstractmethod + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Sequence[dict[str, Any]]": ... + + @overload + @abstractmethod + async def select( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Sequence[ModelDTOT]": ... + @abstractmethod async def select( self, @@ -653,7 +915,33 @@ async def select( connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... + ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": ... + + @overload + @abstractmethod + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + + @overload + @abstractmethod + async def select_one( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... @abstractmethod async def select_one( @@ -667,6 +955,32 @@ async def select_one( **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... + @overload + @abstractmethod + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + + @overload + @abstractmethod + async def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... + @abstractmethod async def select_one_or_none( self, @@ -679,6 +993,32 @@ async def select_one_or_none( **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... + @overload + @abstractmethod + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Any": ... + + @overload + @abstractmethod + async def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "T": ... + @abstractmethod async def select_value( self, @@ -689,7 +1029,33 @@ async def select_value( connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, - ) -> "Union[Any, T]": ... + ) -> "Union[T, Any]": ... + + @overload + @abstractmethod + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[Any]": ... + + @overload + @abstractmethod + async def select_value_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> "Optional[T]": ... @abstractmethod async def select_value_or_none( @@ -701,7 +1067,7 @@ async def select_value_or_none( connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, - ) -> "Optional[Union[Any, T]]": ... + ) -> "Optional[Union[T, Any]]": ... @abstractmethod async def insert_update_delete( @@ -714,6 +1080,32 @@ async def insert_update_delete( **kwargs: Any, ) -> int: ... + @overload + @abstractmethod + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "dict[str, Any]": ... + + @overload + @abstractmethod + async def insert_update_delete_returning( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[ConnectionT]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "ModelDTOT": ... + @abstractmethod async def insert_update_delete_returning( self, @@ -724,7 +1116,7 @@ async def insert_update_delete_returning( connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... + ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod async def execute_script( diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 7eea9f783..fd2b5564d 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -7,6 +7,7 @@ AsyncConfigT, ConnectionT, DatabaseConfigProtocol, + DriverT, PoolT, SyncConfigT, ) @@ -75,6 +76,7 @@ def on_app_init(self, app_config: "AppConfig") -> "AppConfig": SQLSpec, ConnectionT, PoolT, + DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT,