Skip to content

Commit eb63bfc

Browse files
authored
feat: add TypedDict support to schema_type conversion (#91)
Adds `TypedDict` support to the to_schema method's `schema_type` parameter
1 parent 1177c6b commit eb63bfc

File tree

4 files changed

+131
-4
lines changed

4 files changed

+131
-4
lines changed

sqlspec/driver/mixins/_result_tools.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from enum import Enum
88
from functools import partial
99
from pathlib import Path, PurePath
10-
from typing import Any, Callable, Final, Optional, overload
10+
from typing import Any, Callable, Final, Optional, TypeVar, Union, cast, overload
1111
from uuid import UUID
1212

1313
from mypy_extensions import trait
@@ -33,13 +33,16 @@
3333
is_dict,
3434
is_msgspec_struct,
3535
is_pydantic_model,
36+
is_typed_dict,
3637
)
3738

3839
__all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
3940

4041

4142
logger = logging.getLogger(__name__)
4243

44+
# TypeVar for TypedDict support - not bound since TypedDict classes aren't dict subclasses
45+
TypedDictT = TypeVar("TypedDictT")
4346

4447
_DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time}
4548

@@ -165,10 +168,20 @@ def to_schema(data: "list[ModelT]", *, schema_type: None = None) -> "list[ModelT
165168
def to_schema(data: "ModelT") -> "ModelT": ...
166169
@overload
167170
@staticmethod
168-
def to_schema(data: Any, *, schema_type: None = None) -> Any: ...
171+
def to_schema(data: "dict[str, Any]", *, schema_type: "type[TypedDictT]") -> "TypedDictT": ...
172+
@overload
173+
@staticmethod
174+
def to_schema(data: "list[dict[str, Any]]", *, schema_type: "type[TypedDictT]") -> "list[TypedDictT]": ...
175+
@overload
176+
@staticmethod
177+
def to_schema(data: Any, *, schema_type: "type[TypedDictT]") -> Any: ...
169178

179+
@overload
170180
@staticmethod
171-
def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) -> Any:
181+
def to_schema(data: Any, *, schema_type: None = None) -> Any: ...
182+
183+
@staticmethod # type: ignore[misc,unused-ignore]
184+
def to_schema(data: Any, *, schema_type: "Optional[type[Union[ModelDTOT, TypedDictT]]]" = None) -> Any: # type: ignore[misc,unused-ignore]
172185
"""Convert data to a specified schema type.
173186
174187
Args:
@@ -183,6 +196,10 @@ def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) ->
183196
"""
184197
if schema_type is None:
185198
return data
199+
if is_typed_dict(schema_type):
200+
if isinstance(data, list):
201+
return cast("list[ModelDTOT]", [item for item in data if is_dict(item)])
202+
return cast("ModelDTOT", data)
186203
if is_dataclass(schema_type):
187204
if isinstance(data, list):
188205
result: list[Any] = []
@@ -273,5 +290,5 @@ def _convert_numpy_arrays_in_data(obj: Any) -> Any:
273290
if isinstance(data, dict):
274291
return schema_type(**data)
275292
return data
276-
msg = "`schema_type` should be a valid Dataclass, Pydantic model, Msgspec struct, or Attrs class"
293+
msg = "`schema_type` should be a valid Dataclass, Pydantic model, Msgspec struct, Attrs class, or TypedDict"
277294
raise SQLSpecError(msg)

sqlspec/utils/type_guards.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from functools import lru_cache
1010
from typing import TYPE_CHECKING, Any, Optional, Union, cast
1111

12+
from typing_extensions import is_typeddict
13+
1214
from sqlspec.typing import (
1315
ATTRS_INSTALLED,
1416
LITESTAR_INSTALLED,
@@ -117,6 +119,7 @@
117119
"is_select_builder",
118120
"is_statement_filter",
119121
"is_string_literal",
122+
"is_typed_dict",
120123
"is_typed_parameter",
121124
"schema_dump",
122125
"supports_limit",
@@ -126,6 +129,18 @@
126129
)
127130

128131

132+
def is_typed_dict(obj: Any) -> "TypeGuard[type]":
133+
"""Check if an object is a TypedDict class.
134+
135+
Args:
136+
obj: The object to check
137+
138+
Returns:
139+
True if the object is a TypedDict class, False otherwise
140+
"""
141+
return is_typeddict(obj)
142+
143+
129144
def is_statement_filter(obj: Any) -> "TypeGuard[StatementFilter]":
130145
"""Check if an object implements the StatementFilter protocol.
131146

tests/unit/test_driver/test_result_tools.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import msgspec
1212
import pytest
13+
from typing_extensions import TypedDict
1314

1415
from sqlspec.driver.mixins._result_tools import (
1516
_DEFAULT_TYPE_DECODERS,
@@ -39,6 +40,14 @@ class SampleMsgspecStructWithIntList(msgspec.Struct):
3940
values: "list[int]"
4041

4142

43+
class SampleTypedDict(TypedDict):
44+
"""Sample TypedDict for testing."""
45+
46+
name: str
47+
age: int
48+
optional_field: "Optional[str]"
49+
50+
4251
# Test _is_list_type_target function
4352
def test_is_list_type_target_with_list_types() -> None:
4453
"""Test detection of list type targets."""
@@ -290,6 +299,57 @@ def test_to_schema_mixin_without_schema_type() -> None:
290299
assert result == test_data
291300

292301

302+
def test_to_schema_mixin_with_typeddict_single_record() -> None:
303+
"""Test ToSchemaMixin.to_schema with TypedDict for single record."""
304+
test_data = {"name": "test_user", "age": 30, "optional_field": "value"}
305+
306+
result = ToSchemaMixin.to_schema(test_data, schema_type=SampleTypedDict)
307+
308+
assert result == test_data
309+
assert isinstance(result, dict)
310+
311+
312+
def test_to_schema_mixin_with_typeddict_multiple_records() -> None:
313+
"""Test ToSchemaMixin.to_schema with TypedDict for multiple records."""
314+
test_data = [
315+
{"name": "user1", "age": 25, "optional_field": "value1"},
316+
{"name": "user2", "age": 30, "optional_field": "value2"},
317+
]
318+
319+
result = ToSchemaMixin.to_schema(test_data, schema_type=SampleTypedDict)
320+
321+
assert isinstance(result, list)
322+
assert len(result) == 2
323+
for item in result:
324+
assert isinstance(item, dict)
325+
assert result == test_data
326+
327+
328+
def test_to_schema_mixin_with_typeddict_mixed_data() -> None:
329+
"""Test ToSchemaMixin.to_schema with TypedDict filters non-dict items."""
330+
test_data = [
331+
{"name": "user1", "age": 25, "optional_field": "value1"},
332+
"not_a_dict", # This should be filtered out
333+
{"name": "user2", "age": 30, "optional_field": "value2"},
334+
]
335+
336+
result = ToSchemaMixin.to_schema(test_data, schema_type=SampleTypedDict)
337+
338+
assert isinstance(result, list)
339+
assert len(result) == 2 # Only dict items should be included
340+
for item in result:
341+
assert isinstance(item, dict)
342+
343+
344+
def test_to_schema_mixin_with_typeddict_non_dict_data() -> None:
345+
"""Test ToSchemaMixin.to_schema with TypedDict returns non-dict data unchanged."""
346+
test_data = "not_a_dict"
347+
348+
result = ToSchemaMixin.to_schema(test_data, schema_type=SampleTypedDict)
349+
350+
assert result == test_data
351+
352+
293353
@pytest.mark.skipif(not NUMPY_INSTALLED, reason="numpy not installed")
294354
def test_numpy_array_conversion_edge_cases() -> None:
295355
"""Test edge cases for numpy array conversion."""

tests/unit/test_utils/test_type_guards.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import msgspec
1111
import pytest
1212
from sqlglot import exp
13+
from typing_extensions import TypedDict
1314

1415
from sqlspec.utils.type_guards import (
1516
dataclass_to_dict,
@@ -57,6 +58,7 @@
5758
is_schema_with_field,
5859
is_schema_without_field,
5960
is_string_literal,
61+
is_typed_dict,
6062
schema_dump,
6163
)
6264

@@ -74,6 +76,14 @@ class SampleDataclass:
7476
optional_field: "Optional[str]" = None
7577

7678

79+
class SampleTypedDict(TypedDict):
80+
"""Sample TypedDict for testing."""
81+
82+
name: str
83+
age: int
84+
optional_field: "Optional[str]"
85+
86+
7787
class MockSQLGlotExpression:
7888
"""Mock SQLGlot expression for testing type guard functions.
7989
@@ -932,6 +942,31 @@ def test_get_msgspec_rename_config_with_pascal_rename() -> None:
932942
assert result == "pascal"
933943

934944

945+
def test_is_typed_dict_with_typeddict_class() -> None:
946+
"""Test is_typed_dict returns True for TypedDict classes."""
947+
assert is_typed_dict(SampleTypedDict) is True
948+
949+
950+
def test_is_typed_dict_with_typeddict_instance() -> None:
951+
"""Test is_typed_dict returns False for TypedDict instances (they are dicts)."""
952+
sample_data: SampleTypedDict = {"name": "test", "age": 25, "optional_field": "value"}
953+
assert is_typed_dict(sample_data) is False
954+
955+
956+
def test_is_typed_dict_with_non_typeddict() -> None:
957+
"""Test is_typed_dict returns False for non-TypedDict types."""
958+
assert is_typed_dict(dict) is False
959+
assert is_typed_dict(SampleDataclass) is False
960+
assert is_typed_dict(str) is False
961+
assert is_typed_dict(42) is False
962+
assert is_typed_dict({}) is False
963+
964+
965+
def test_is_typed_dict_with_regular_dict() -> None:
966+
"""Test is_typed_dict returns False for regular dict instances."""
967+
assert is_typed_dict({"key": "value"}) is False
968+
969+
935970
def test_get_msgspec_rename_config_without_rename() -> None:
936971
"""Test get_msgspec_rename_config returns None when no rename config."""
937972
schema_type = MockMsgspecStructWithoutRename

0 commit comments

Comments
 (0)