Skip to content

Commit 4e334e2

Browse files
authored
fix(litestar): signature namespace (#41)
Corrects a few issues with Litestar signature namespace support
1 parent 26bd85d commit 4e334e2

File tree

10 files changed

+116
-39
lines changed

10 files changed

+116
-39
lines changed

docs/examples/litestar_asyncpg.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,8 @@
3232
@get("/")
3333
async def hello_world(db_session: AsyncpgDriver) -> dict[str, Any]:
3434
"""Simple endpoint that returns a greeting from the database."""
35-
# Execute a simple query that doesn't require real database tables
3635
result = await db_session.execute(SQL("SELECT 'Hello from AsyncPG!' as greeting"))
37-
38-
# Return the first row as a dictionary
39-
if result.data:
40-
return result.data[0]
41-
return {"greeting": "No data returned"}
36+
return result.get_first() or {"greeting": "No data returned"}
4237

4338

4439
@get("/version")

docs/examples/litestar_duckllm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# type: ignore
21
"""Litestar DuckLLM
32
43
This example demonstrates how to use the Litestar framework with the DuckLLM extension.
@@ -27,7 +26,8 @@ class ChatMessage(Struct):
2726

2827
@post("/chat", sync_to_thread=True)
2928
def duckllm_chat(db_session: DuckDBDriver, data: ChatMessage) -> ChatMessage:
30-
return db_session.execute("SELECT open_prompt(?)", data.message, schema_type=ChatMessage).get_first()
29+
results = db_session.execute("SELECT open_prompt(?)", data.message, schema_type=ChatMessage).get_first()
30+
return results or ChatMessage(message="No response from DuckLLM")
3131

3232

3333
sqlspec = SQLSpec(

docs/examples/litestar_multi_db.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# type: ignore
21
"""Litestar Multi DB
32
43
This example demonstrates how to use multiple databases in a Litestar application.
@@ -19,22 +18,21 @@
1918
from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver
2019
from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver
2120
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec
21+
from sqlspec.statement.sql import SQL
2222

2323

2424
@get("/test", sync_to_thread=True)
2525
def simple_select(etl_session: DuckDBDriver) -> dict[str, str]:
26-
from sqlspec.statement.sql import SQL
27-
2826
result = etl_session.execute(SQL("SELECT 'Hello, ETL world!' AS greeting"))
2927
greeting = result.get_first()
3028
return {"greeting": greeting["greeting"] if greeting is not None else "hi"}
3129

3230

3331
@get("/")
3432
async def simple_sqlite(db_session: AiosqliteDriver) -> dict[str, str]:
35-
from sqlspec.statement.sql import SQL
36-
37-
return await db_session.select_one(SQL("SELECT 'Hello, world!' AS greeting"))
33+
result = await db_session.execute("SELECT 'Hello, world!' AS greeting")
34+
greeting = result.get_first()
35+
return {"greeting": greeting["greeting"] if greeting is not None else "hi"}
3836

3937

4038
sqlspec = SQLSpec(

docs/examples/unified_storage_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def demo_unified_storage_architecture() -> None:
3737
range AS id,
3838
'Product_' || (range % 10) AS product_name,
3939
(random() * 1000)::int AS amount,
40-
DATE '2024-01-01' + (range % 365) AS sale_date
40+
DATE '2024-01-01' + (range % 365)::INTEGER AS sale_date
4141
FROM range(1000)
4242
""")
4343
)

sqlspec/adapters/aiosqlite/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,16 @@ async def provide_session(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Aio
186186
async def provide_pool(self, *args: Any, **kwargs: Any) -> None:
187187
"""Aiosqlite doesn't support pooling."""
188188
return
189+
190+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
191+
"""Get the signature namespace for Aiosqlite types.
192+
193+
This provides all Aiosqlite-specific types that Litestar needs to recognize
194+
to avoid serialization attempts.
195+
196+
Returns:
197+
Dictionary mapping type names to types.
198+
"""
199+
namespace = super().get_signature_namespace()
200+
namespace.update({"AiosqliteConnection": AiosqliteConnection})
201+
return namespace

sqlspec/adapters/asyncmy/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
77

88
import asyncmy
9+
from asyncmy.pool import Pool as AsyncmyPool
910

1011
from sqlspec.adapters.asyncmy.driver import AsyncmyConnection, AsyncmyDriver
1112
from sqlspec.config import AsyncDatabaseConfig
@@ -283,3 +284,16 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: i
283284
if not self.pool_instance:
284285
self.pool_instance = await self.create_pool()
285286
return self.pool_instance
287+
288+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
289+
"""Get the signature namespace for Asyncmy types.
290+
291+
This provides all Asyncmy-specific types that Litestar needs to recognize
292+
to avoid serialization attempts.
293+
294+
Returns:
295+
Dictionary mapping type names to types.
296+
"""
297+
namespace = super().get_signature_namespace()
298+
namespace.update({"AsyncmyConnection": AsyncmyConnection, "AsyncmyPool": AsyncmyPool})
299+
return namespace

sqlspec/adapters/asyncpg/config.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from contextlib import asynccontextmanager
66
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
77

8-
from asyncpg import Record
8+
from asyncpg import Connection, Record
99
from asyncpg import create_pool as asyncpg_create_pool
10+
from asyncpg.connection import ConnectionMeta
11+
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
1012
from typing_extensions import NotRequired, Unpack
1113

1214
from sqlspec.adapters.asyncpg.driver import AsyncpgConnection, AsyncpgDriver
@@ -18,7 +20,6 @@
1820
if TYPE_CHECKING:
1921
from asyncio.events import AbstractEventLoop
2022

21-
from asyncpg.pool import Pool
2223
from sqlglot.dialects.dialect import DialectType
2324

2425

@@ -347,24 +348,15 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
347348
Dictionary mapping type names to types.
348349
"""
349350
namespace = super().get_signature_namespace()
350-
351-
try:
352-
from asyncpg import Connection, Record
353-
from asyncpg.connection import ConnectionMeta
354-
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
355-
356-
namespace.update(
357-
{
358-
"Connection": Connection,
359-
"Pool": Pool,
360-
"PoolConnectionProxy": PoolConnectionProxy,
361-
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
362-
"ConnectionMeta": ConnectionMeta,
363-
"Record": Record,
364-
"AsyncpgConnection": type(AsyncpgConnection), # The Union type alias
365-
}
366-
)
367-
except ImportError:
368-
logger.warning("Failed to import AsyncPG types for signature namespace")
369-
351+
namespace.update(
352+
{
353+
"Connection": Connection,
354+
"Pool": Pool,
355+
"PoolConnectionProxy": PoolConnectionProxy,
356+
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
357+
"ConnectionMeta": ConnectionMeta,
358+
"Record": Record,
359+
"AsyncpgConnection": type(AsyncpgConnection),
360+
}
361+
)
370362
return namespace

sqlspec/adapters/oracledb/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,19 @@ def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool":
315315
self.pool_instance = self.create_pool()
316316
return self.pool_instance
317317

318+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
319+
"""Get the signature namespace for OracleDB types.
320+
321+
This provides all OracleDB-specific types that Litestar needs to recognize
322+
to avoid serialization attempts.
323+
324+
Returns:
325+
Dictionary mapping type names to types.
326+
"""
327+
namespace = super().get_signature_namespace()
328+
namespace.update({"OracleSyncConnection": OracleSyncConnection, "OracleAsyncConnection": OracleAsyncConnection})
329+
return namespace
330+
318331
@property
319332
def connection_config_dict(self) -> dict[str, Any]:
320333
"""Return the connection configuration as a dict for Oracle operations.
@@ -624,3 +637,16 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncConnectionPool"
624637
if not self.pool_instance:
625638
self.pool_instance = await self.create_pool()
626639
return self.pool_instance
640+
641+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
642+
"""Get the signature namespace for OracleDB async types.
643+
644+
This provides all OracleDB async-specific types that Litestar needs to recognize
645+
to avoid serialization attempts.
646+
647+
Returns:
648+
Dictionary mapping type names to types.
649+
"""
650+
namespace = super().get_signature_namespace()
651+
namespace.update({"OracleSyncConnection": OracleSyncConnection, "OracleAsyncConnection": OracleAsyncConnection})
652+
return namespace

sqlspec/adapters/psqlpy/config.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def pool_config_dict(self) -> dict[str, Any]:
324324

325325
return config
326326

327-
async def _create_pool(self) -> ConnectionPool:
327+
async def _create_pool(self) -> "ConnectionPool":
328328
"""Create the actual async connection pool."""
329329
logger.info("Creating psqlpy connection pool", extra={"adapter": "psqlpy"})
330330

@@ -351,7 +351,7 @@ async def _close_pool(self) -> None:
351351
logger.exception("Failed to close psqlpy connection pool", extra={"adapter": "psqlpy", "error": str(e)})
352352
raise
353353

354-
async def create_connection(self) -> PsqlpyConnection:
354+
async def create_connection(self) -> "PsqlpyConnection":
355355
"""Create a single async connection (not from pool).
356356
357357
Returns:
@@ -413,3 +413,16 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> ConnectionPool:
413413
if not self.pool_instance:
414414
self.pool_instance = await self.create_pool()
415415
return self.pool_instance
416+
417+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
418+
"""Get the signature namespace for Psqlpy types.
419+
420+
This provides all Psqlpy-specific types that Litestar needs to recognize
421+
to avoid serialization attempts.
422+
423+
Returns:
424+
Dictionary mapping type names to types.
425+
"""
426+
namespace = super().get_signature_namespace()
427+
namespace.update({"PsqlpyConnection": PsqlpyConnection})
428+
return namespace

sqlspec/adapters/psycopg/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,19 @@ def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool":
400400
self.pool_instance = self.create_pool()
401401
return self.pool_instance
402402

403+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
404+
"""Get the signature namespace for Psycopg types.
405+
406+
This provides all Psycopg-specific types that Litestar needs to recognize
407+
to avoid serialization attempts.
408+
409+
Returns:
410+
Dictionary mapping type names to types.
411+
"""
412+
namespace = super().get_signature_namespace()
413+
namespace.update({"PsycopgSyncConnection": PsycopgSyncConnection})
414+
return namespace
415+
403416

404417
class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]):
405418
"""Configuration for Psycopg asynchronous database connections with direct field-based configuration."""
@@ -727,3 +740,16 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncConnectionPool"
727740
if not self.pool_instance:
728741
self.pool_instance = await self.create_pool()
729742
return self.pool_instance
743+
744+
def get_signature_namespace(self) -> "dict[str, type[Any]]":
745+
"""Get the signature namespace for Psycopg async types.
746+
747+
This provides all Psycopg async-specific types that Litestar needs to recognize
748+
to avoid serialization attempts.
749+
750+
Returns:
751+
Dictionary mapping type names to types.
752+
"""
753+
namespace = super().get_signature_namespace()
754+
namespace.update({"PsycopgAsyncConnection": PsycopgAsyncConnection})
755+
return namespace

0 commit comments

Comments
 (0)