Skip to content

Commit 94c482f

Browse files
authored
fix: add TypedDict support to driver method overloads (#105)
- Adds TypedDict support to all driver methods (select, select_one, select_one_or_none, select_with_total) - Introduces SchemaT TypeVar in typing.py for unified schema type handling - Fixes Pyright type checking errors when using TypedDict classes with schema_type parameter
1 parent 8a63277 commit 94c482f

File tree

4 files changed

+52
-50
lines changed

4 files changed

+52
-50
lines changed

sqlspec/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
ModelT,
4242
PoolT,
4343
RowT,
44+
SchemaT,
4445
StatementParameters,
4546
SupportedSchemaModel,
4647
)
@@ -80,6 +81,7 @@
8081
"SQLFileLoader",
8182
"SQLResult",
8283
"SQLSpec",
84+
"SchemaT",
8385
"Select",
8486
"Statement",
8587
"StatementConfig",

sqlspec/driver/_async.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Asynchronous driver protocol implementation."""
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Any, Final, NoReturn, TypeVar, cast, overload
4+
from typing import TYPE_CHECKING, Any, Final, NoReturn, TypeVar, overload
55

66
from sqlspec.core import SQL, Statement
77
from sqlspec.driver._common import CommonDriverAttributesMixin, DataDictionaryMixin, ExecutionResult, VersionInfo
@@ -16,7 +16,7 @@
1616

1717
from sqlspec.builder import QueryBuilder
1818
from sqlspec.core import SQLResult, StatementConfig, StatementFilter
19-
from sqlspec.typing import ModelDTOT, StatementParameters
19+
from sqlspec.typing import SchemaT, StatementParameters
2020

2121
_LOGGER_NAME: Final[str] = "sqlspec"
2222
logger = get_logger(_LOGGER_NAME)
@@ -228,10 +228,10 @@ async def select_one(
228228
statement: "Statement | QueryBuilder",
229229
/,
230230
*parameters: "StatementParameters | StatementFilter",
231-
schema_type: "type[ModelDTOT]",
231+
schema_type: "type[SchemaT]",
232232
statement_config: "StatementConfig | None" = None,
233233
**kwargs: Any,
234-
) -> "ModelDTOT": ...
234+
) -> "SchemaT": ...
235235

236236
@overload
237237
async def select_one(
@@ -249,10 +249,10 @@ async def select_one(
249249
statement: "Statement | QueryBuilder",
250250
/,
251251
*parameters: "StatementParameters | StatementFilter",
252-
schema_type: "type[ModelDTOT] | None" = None,
252+
schema_type: "type[Any] | None" = None,
253253
statement_config: "StatementConfig | None" = None,
254254
**kwargs: Any,
255-
) -> "dict[str, Any] | ModelDTOT":
255+
) -> Any:
256256
"""Execute a select statement and return exactly one row.
257257
258258
Raises an exception if no rows or more than one row is returned.
@@ -273,10 +273,10 @@ async def select_one_or_none(
273273
statement: "Statement | QueryBuilder",
274274
/,
275275
*parameters: "StatementParameters | StatementFilter",
276-
schema_type: "type[ModelDTOT]",
276+
schema_type: "type[SchemaT]",
277277
statement_config: "StatementConfig | None" = None,
278278
**kwargs: Any,
279-
) -> "ModelDTOT | None": ...
279+
) -> "SchemaT | None": ...
280280

281281
@overload
282282
async def select_one_or_none(
@@ -294,10 +294,10 @@ async def select_one_or_none(
294294
statement: "Statement | QueryBuilder",
295295
/,
296296
*parameters: "StatementParameters | StatementFilter",
297-
schema_type: "type[ModelDTOT] | None" = None,
297+
schema_type: "type[Any] | None" = None,
298298
statement_config: "StatementConfig | None" = None,
299299
**kwargs: Any,
300-
) -> "dict[str, Any] | ModelDTOT | None":
300+
) -> Any:
301301
"""Execute a select statement and return at most one row.
302302
303303
Returns None if no rows are found.
@@ -311,21 +311,18 @@ async def select_one_or_none(
311311
if data_len > 1:
312312
self._raise_expected_at_most_one_row(data_len)
313313
first_row = data[0]
314-
return cast(
315-
"dict[str, Any] | ModelDTOT | None",
316-
self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
317-
)
314+
return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
318315

319316
@overload
320317
async def select(
321318
self,
322319
statement: "Statement | QueryBuilder",
323320
/,
324321
*parameters: "StatementParameters | StatementFilter",
325-
schema_type: "type[ModelDTOT]",
322+
schema_type: "type[SchemaT]",
326323
statement_config: "StatementConfig | None" = None,
327324
**kwargs: Any,
328-
) -> "list[ModelDTOT]": ...
325+
) -> "list[SchemaT]": ...
329326

330327
@overload
331328
async def select(
@@ -343,15 +340,13 @@ async def select(
343340
statement: "Statement | QueryBuilder",
344341
/,
345342
*parameters: "StatementParameters | StatementFilter",
346-
schema_type: "type[ModelDTOT] | None" = None,
343+
schema_type: "type[Any] | None" = None,
347344
statement_config: "StatementConfig | None" = None,
348345
**kwargs: Any,
349-
) -> "list[dict[str, Any]] | list[ModelDTOT]":
346+
) -> Any:
350347
"""Execute a select statement and return all rows."""
351348
result = await self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
352-
return cast(
353-
"list[dict[str, Any]] | list[ModelDTOT]", self.to_schema(result.get_data(), schema_type=schema_type)
354-
)
349+
return self.to_schema(result.get_data(), schema_type=schema_type)
355350

356351
async def select_value(
357352
self,
@@ -421,10 +416,10 @@ async def select_with_total(
421416
statement: "Statement | QueryBuilder",
422417
/,
423418
*parameters: "StatementParameters | StatementFilter",
424-
schema_type: "type[ModelDTOT]",
419+
schema_type: "type[SchemaT]",
425420
statement_config: "StatementConfig | None" = None,
426421
**kwargs: Any,
427-
) -> "tuple[list[ModelDTOT], int]": ...
422+
) -> "tuple[list[SchemaT], int]": ...
428423

429424
@overload
430425
async def select_with_total(
@@ -442,10 +437,10 @@ async def select_with_total(
442437
statement: "Statement | QueryBuilder",
443438
/,
444439
*parameters: "StatementParameters | StatementFilter",
445-
schema_type: "type[ModelDTOT] | None" = None,
440+
schema_type: "type[Any] | None" = None,
446441
statement_config: "StatementConfig | None" = None,
447442
**kwargs: Any,
448-
) -> "tuple[list[dict[str, Any]] | list[ModelDTOT], int]":
443+
) -> Any:
449444
"""Execute a select statement and return both the data and total count.
450445
451446
This method is designed for pagination scenarios where you need both

sqlspec/driver/_sync.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Synchronous driver protocol implementation."""
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Any, Final, NoReturn, TypeVar, cast, overload
4+
from typing import TYPE_CHECKING, Any, Final, NoReturn, TypeVar, overload
55

66
from sqlspec.core import SQL
77
from sqlspec.driver._common import CommonDriverAttributesMixin, DataDictionaryMixin, ExecutionResult, VersionInfo
@@ -16,7 +16,7 @@
1616

1717
from sqlspec.builder import QueryBuilder
1818
from sqlspec.core import SQLResult, Statement, StatementConfig, StatementFilter
19-
from sqlspec.typing import ModelDTOT, StatementParameters
19+
from sqlspec.typing import SchemaT, StatementParameters
2020

2121
_LOGGER_NAME: Final[str] = "sqlspec"
2222
logger = get_logger(_LOGGER_NAME)
@@ -228,10 +228,10 @@ def select_one(
228228
statement: "Statement | QueryBuilder",
229229
/,
230230
*parameters: "StatementParameters | StatementFilter",
231-
schema_type: "type[ModelDTOT]",
231+
schema_type: "type[SchemaT]",
232232
statement_config: "StatementConfig | None" = None,
233233
**kwargs: Any,
234-
) -> "ModelDTOT": ...
234+
) -> "SchemaT": ...
235235

236236
@overload
237237
def select_one(
@@ -249,10 +249,10 @@ def select_one(
249249
statement: "Statement | QueryBuilder",
250250
/,
251251
*parameters: "StatementParameters | StatementFilter",
252-
schema_type: "type[ModelDTOT] | None" = None,
252+
schema_type: "type[Any] | None" = None,
253253
statement_config: "StatementConfig | None" = None,
254254
**kwargs: Any,
255-
) -> "dict[str, Any] | ModelDTOT":
255+
) -> Any:
256256
"""Execute a select statement and return exactly one row.
257257
258258
Raises an exception if no rows or more than one row is returned.
@@ -273,10 +273,10 @@ def select_one_or_none(
273273
statement: "Statement | QueryBuilder",
274274
/,
275275
*parameters: "StatementParameters | StatementFilter",
276-
schema_type: "type[ModelDTOT]",
276+
schema_type: "type[SchemaT]",
277277
statement_config: "StatementConfig | None" = None,
278278
**kwargs: Any,
279-
) -> "ModelDTOT | None": ...
279+
) -> "SchemaT | None": ...
280280

281281
@overload
282282
def select_one_or_none(
@@ -294,10 +294,10 @@ def select_one_or_none(
294294
statement: "Statement | QueryBuilder",
295295
/,
296296
*parameters: "StatementParameters | StatementFilter",
297-
schema_type: "type[ModelDTOT] | None" = None,
297+
schema_type: "type[Any] | None" = None,
298298
statement_config: "StatementConfig | None" = None,
299299
**kwargs: Any,
300-
) -> "dict[str, Any] | ModelDTOT | None":
300+
) -> Any:
301301
"""Execute a select statement and return at most one row.
302302
303303
Returns None if no rows are found.
@@ -311,21 +311,18 @@ def select_one_or_none(
311311
if data_len > 1:
312312
self._raise_expected_at_most_one_row(data_len)
313313
first_row = data[0]
314-
return cast(
315-
"dict[str, Any] | ModelDTOT | None",
316-
self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row,
317-
)
314+
return self.to_schema(first_row, schema_type=schema_type) if schema_type else first_row
318315

319316
@overload
320317
def select(
321318
self,
322319
statement: "Statement | QueryBuilder",
323320
/,
324321
*parameters: "StatementParameters | StatementFilter",
325-
schema_type: "type[ModelDTOT]",
322+
schema_type: "type[SchemaT]",
326323
statement_config: "StatementConfig | None" = None,
327324
**kwargs: Any,
328-
) -> "list[ModelDTOT]": ...
325+
) -> "list[SchemaT]": ...
329326

330327
@overload
331328
def select(
@@ -343,15 +340,13 @@ def select(
343340
statement: "Statement | QueryBuilder",
344341
/,
345342
*parameters: "StatementParameters | StatementFilter",
346-
schema_type: "type[ModelDTOT] | None" = None,
343+
schema_type: "type[Any] | None" = None,
347344
statement_config: "StatementConfig | None" = None,
348345
**kwargs: Any,
349-
) -> "list[dict[str, Any]] | list[ModelDTOT]":
346+
) -> Any:
350347
"""Execute a select statement and return all rows."""
351348
result = self.execute(statement, *parameters, statement_config=statement_config, **kwargs)
352-
return cast(
353-
"list[dict[str, Any]] | list[ModelDTOT]", self.to_schema(result.get_data(), schema_type=schema_type)
354-
)
349+
return self.to_schema(result.get_data(), schema_type=schema_type)
355350

356351
def select_value(
357352
self,
@@ -422,10 +417,10 @@ def select_with_total(
422417
statement: "Statement | QueryBuilder",
423418
/,
424419
*parameters: "StatementParameters | StatementFilter",
425-
schema_type: "type[ModelDTOT]",
420+
schema_type: "type[SchemaT]",
426421
statement_config: "StatementConfig | None" = None,
427422
**kwargs: Any,
428-
) -> "tuple[list[ModelDTOT], int]": ...
423+
) -> "tuple[list[SchemaT], int]": ...
429424

430425
@overload
431426
def select_with_total(
@@ -443,10 +438,10 @@ def select_with_total(
443438
statement: "Statement | QueryBuilder",
444439
/,
445440
*parameters: "StatementParameters | StatementFilter",
446-
schema_type: "type[ModelDTOT] | None" = None,
441+
schema_type: "type[Any] | None" = None,
447442
statement_config: "StatementConfig | None" = None,
448443
**kwargs: Any,
449-
) -> "tuple[list[dict[str, Any]] | list[ModelDTOT], int]":
444+
) -> Any:
450445
"""Execute a select statement and return both the data and total count.
451446
452447
This method is designed for pagination scenarios where you need both

sqlspec/typing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def __len__(self) -> int: ...
9999
:class:`DictLike` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`DataclassProtocol` | :class:`AttrsInstance`
100100
"""
101101
RowT = TypeVar("RowT", bound="dict[str, Any]")
102+
SchemaT = TypeVar("SchemaT")
103+
"""Type variable for schema types (models, TypedDict, dataclasses, etc.).
104+
105+
Unbounded TypeVar for use with schema_type parameter in driver methods.
106+
Supports all schema types including TypedDict which cannot be bounded to a class hierarchy.
107+
"""
102108

103109

104110
DictRow: TypeAlias = "dict[str, Any]"
@@ -126,6 +132,9 @@ def __len__(self) -> int: ...
126132
"""Type variable for model DTOs.
127133
128134
:class:`msgspec.Struct`|:class:`pydantic.BaseModel`
135+
136+
.. deprecated:: 0.27.0
137+
Use :class:`SchemaT` instead. This TypeVar will be removed in a future version.
129138
"""
130139
PydanticOrMsgspecT = SupportedSchemaModel
131140
"""Type alias for pydantic or msgspec models.
@@ -239,6 +248,7 @@ class StorageMixin(MixinOf(DriverProtocol)): ...
239248
"PoolT_co",
240249
"PydanticOrMsgspecT",
241250
"RowT",
251+
"SchemaT",
242252
"Span",
243253
"StatementParameters",
244254
"Status",

0 commit comments

Comments
 (0)