Skip to content

Commit 51371a5

Browse files
authored
fix: to_schema detect msgspec rename configurations (#69)
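When populating a msgspec struct from a dictionary, the struct's rename configuration is not applied to the dictionary keys automatically. This change detects the rename configuration of the target struct type and converts the dictionary keys to match before conversion.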
1 parent cb63b6f commit 51371a5

File tree

9 files changed

+817
-166
lines changed

9 files changed

+817
-166
lines changed

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,11 @@ include = [
161161
"sqlspec/loader.py", # Loader module
162162

163163
# === UTILITY MODULES ===
164-
"sqlspec/utils/text.py", # Text utilities
165-
"sqlspec/utils/sync_tools.py", # Synchronous utility functions
166-
"sqlspec/utils/type_guards.py", # Type guard utilities
167-
"sqlspec/utils/fixtures.py", # File fixture loading
164+
"sqlspec/utils/text.py", # Text utilities
165+
"sqlspec/utils/sync_tools.py", # Synchronous utility functions
166+
"sqlspec/utils/type_guards.py", # Type guard utilities
167+
"sqlspec/utils/fixtures.py", # File fixture loading
168+
"sqlspec/utils/data_transformation.py", # Data transformation utilities
168169
]
169170
mypy-args = [
170171
"--ignore-missing-imports",

sqlspec/driver/mixins/_result_tools.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ruff: noqa: C901
12
"""Result handling and schema conversion mixins for database drivers."""
23

34
import datetime
@@ -22,7 +23,16 @@
2223
convert,
2324
get_type_adapter,
2425
)
25-
from sqlspec.utils.type_guards import is_attrs_schema, is_dataclass, is_msgspec_struct, is_pydantic_model
26+
from sqlspec.utils.data_transformation import transform_dict_keys
27+
from sqlspec.utils.text import camelize, kebabize, pascalize
28+
from sqlspec.utils.type_guards import (
29+
get_msgspec_rename_config,
30+
is_attrs_schema,
31+
is_dataclass,
32+
is_dict,
33+
is_msgspec_struct,
34+
is_pydantic_model,
35+
)
2636

2737
__all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")
2838

@@ -143,21 +153,46 @@ def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) ->
143153
if isinstance(data, list):
144154
result: list[Any] = []
145155
for item in data:
146-
if hasattr(item, "keys"):
156+
if is_dict(item):
147157
result.append(schema_type(**dict(item))) # type: ignore[operator]
148158
else:
149159
result.append(item)
150160
return result
151-
if hasattr(data, "keys"):
161+
if is_dict(data):
152162
return schema_type(**dict(data)) # type: ignore[operator]
153163
if isinstance(data, dict):
154164
return schema_type(**data) # type: ignore[operator]
155165
return data
156166
if is_msgspec_struct(schema_type):
167+
rename_config = get_msgspec_rename_config(schema_type) # type: ignore[arg-type]
157168
deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
158-
if not isinstance(data, Sequence):
159-
return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
160-
return convert(obj=data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
169+
170+
# Transform field names if rename configuration exists
171+
transformed_data = data
172+
if (rename_config and is_dict(data)) or (isinstance(data, Sequence) and data and is_dict(data[0])):
173+
try:
174+
converter = None
175+
if rename_config == "camel":
176+
converter = camelize
177+
elif rename_config == "kebab":
178+
converter = kebabize
179+
elif rename_config == "pascal":
180+
converter = pascalize
181+
182+
if converter is not None:
183+
if isinstance(data, Sequence):
184+
transformed_data = [
185+
transform_dict_keys(item, converter) if is_dict(item) else item for item in data
186+
]
187+
else:
188+
transformed_data = transform_dict_keys(data, converter) if is_dict(data) else data
189+
except Exception as e:
190+
logger.debug("Field name transformation failed for msgspec schema: %s", e)
191+
transformed_data = data
192+
193+
if not isinstance(transformed_data, Sequence):
194+
return convert(obj=transformed_data, type=schema_type, from_attributes=True, dec_hook=deserializer)
195+
return convert(obj=transformed_data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
161196
if is_pydantic_model(schema_type):
162197
if not isinstance(data, Sequence):
163198
adapter = get_type_adapter(schema_type)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Data transformation utilities for SQLSpec.
2+
3+
Provides functions for transforming data structures, particularly for
4+
field name conversion when mapping database results to schema objects.
5+
Used primarily for msgspec field name conversion with rename configurations.
6+
"""
7+
8+
from typing import Any, Callable, Union
9+
10+
__all__ = ("transform_dict_keys",)
11+
12+
13+
def _safe_convert_key(key: Any, converter: Callable[[str], str]) -> Any:
14+
"""Safely convert a key using the converter function.
15+
16+
Args:
17+
key: Key to convert (may not be a string).
18+
converter: Function to convert string keys.
19+
20+
Returns:
21+
Converted key if conversion succeeds, original key otherwise.
22+
"""
23+
if not isinstance(key, str):
24+
return key
25+
26+
try:
27+
return converter(key)
28+
except (TypeError, ValueError, AttributeError):
29+
# If conversion fails, return the original key
30+
return key
31+
32+
33+
def transform_dict_keys(data: Union[dict, list, Any], converter: Callable[[str], str]) -> Union[dict, list, Any]:
    """Rename the keys of a (possibly nested) data structure.

    Dictionaries are rebuilt with every key passed through ``converter``;
    lists are rebuilt element by element. Both are processed recursively, so
    dictionaries nested inside dictionaries or lists are renamed as well.
    Any value that is neither a ``dict`` nor a ``list`` is returned as-is.

    Args:
        data: The value to transform: a dict, a list, or anything else.
        converter: Callable applied to string keys (e.g. ``camelize``).

    Returns:
        A new dict or list with converted keys, or ``data`` itself when it is
        neither a dict nor a list.

    Examples:
        Convert snake_case keys to camelCase:

        >>> from sqlspec.utils.text import camelize
        >>> transform_dict_keys({"user_id": 123, "created_at": "2024-01-01"}, camelize)
        {'userId': 123, 'createdAt': '2024-01-01'}

        Nested dictionaries and lists of dictionaries are converted too:

        >>> transform_dict_keys({"order_items": [{"item_id": 1}]}, camelize)
        {'orderItems': [{'itemId': 1}]}
    """
    if isinstance(data, list):
        return _transform_list(data, converter)
    return _transform_dict(data, converter) if isinstance(data, dict) else data
82+
83+
84+
def _transform_dict(data: dict, converter: Callable[[str], str]) -> dict:
    """Build a copy of ``data`` with converted keys and transformed values.

    Args:
        data: The dictionary whose keys should be renamed.
        converter: Callable applied to string keys.

    Returns:
        A new dictionary. Each key has gone through ``_safe_convert_key`` and
        each value through ``transform_dict_keys``, so nested containers are
        handled recursively.
    """
    # If two source keys convert to the same name, the later one wins,
    # exactly as with repeated item assignment.
    return {
        _safe_convert_key(original_key, converter): transform_dict_keys(item, converter)
        for original_key, item in data.items()
    }
107+
108+
109+
def _transform_list(data: list, converter: Callable[[str], str]) -> list:
    """Build a new list with every element passed through ``transform_dict_keys``.

    Args:
        data: The list to walk.
        converter: Callable applied to string keys of any dictionaries found
            inside the list.

    Returns:
        A new list of the same length holding the transformed elements.
    """
    return [transform_dict_keys(element, converter) for element in data]

sqlspec/utils/text.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,7 @@
1919
_SNAKE_CASE_REMOVE_NON_WORD = re.compile(r"[^\w]+", re.UNICODE)
2020
_SNAKE_CASE_MULTIPLE_UNDERSCORES = re.compile(r"__+", re.UNICODE)
2121

22-
__all__ = ("camelize", "check_email", "slugify", "snake_case")
23-
24-
25-
def check_email(email: str) -> str:
26-
"""Validate an email address.
27-
28-
Args:
29-
email: The email to validate.
30-
31-
Raises:
32-
ValueError: If the email is invalid.
33-
34-
Returns:
35-
The validated email.
36-
"""
37-
if "@" not in email:
38-
msg = "Invalid email!"
39-
raise ValueError(msg)
40-
return email.lower()
22+
__all__ = ("camelize", "kebabize", "pascalize", "slugify", "snake_case")
4123

4224

4325
def slugify(value: str, allow_unicode: bool = False, separator: Optional[str] = None) -> str:
@@ -80,6 +62,32 @@ def camelize(string: str) -> str:
8062
return "".join(word if index == 0 else word.capitalize() for index, word in enumerate(string.split("_")))
8163

8264

65+
@lru_cache(maxsize=100)
def kebabize(string: str) -> str:
    """Convert an underscore-separated string to kebab-case.

    The input is split on ``_``, empty segments (from leading, trailing or
    repeated underscores) are dropped, each remaining segment is lower-cased
    and the segments are joined with ``-``.

    Args:
        string: The string to convert.

    Returns:
        The kebab-case version of the string.
    """
    segments = filter(None, string.split("_"))
    return "-".join(map(str.lower, segments))
76+
77+
78+
@lru_cache(maxsize=100)
def pascalize(string: str) -> str:
    """Convert an underscore-separated string to PascalCase.

    The input is split on ``_``, empty segments are dropped, and each
    remaining segment is passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased) before the segments are concatenated.

    Args:
        string: The string to convert.

    Returns:
        The PascalCase version of the string.
    """
    segments = filter(None, string.split("_"))
    return "".join(map(str.capitalize, segments))
89+
90+
8391
@lru_cache(maxsize=100)
8492
def snake_case(string: str) -> str:
8593
"""Convert a string to snake_case.

sqlspec/utils/type_guards.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from collections.abc import Sequence
88
from collections.abc import Set as AbstractSet
9+
from functools import lru_cache
910
from typing import TYPE_CHECKING, Any, Optional, Union, cast
1011

1112
from sqlspec.typing import (
@@ -59,6 +60,7 @@
5960
"extract_dataclass_items",
6061
"get_initial_expression",
6162
"get_literal_parent",
63+
"get_msgspec_rename_config",
6264
"get_node_expressions",
6365
"get_node_this",
6466
"get_param_style_and_name",
@@ -429,6 +431,78 @@ def is_msgspec_struct_without_field(obj: Any, field_name: str) -> "TypeGuard[Str
429431
return False
430432

431433

434+
@lru_cache(maxsize=500)
def _detect_rename_pattern(field_name: str, encode_name: str) -> "Optional[str]":
    """Work out which naming style turns ``field_name`` into ``encode_name``.

    Args:
        field_name: The attribute name as declared (e.g., "user_id").
        encode_name: The name used when encoding (e.g., "userId").

    Returns:
        "camel", "kebab" or "pascal" for the first style whose conversion of
        ``field_name`` equals ``encode_name``; None when the two names are
        identical or no style matches.
    """
    from sqlspec.utils.text import camelize, kebabize, pascalize

    # An unchanged name cannot identify a rename style.
    if encode_name == field_name:
        return None

    # Styles are tried in this order; the first match is reported.
    candidates = (("camel", camelize), ("kebab", kebabize), ("pascal", pascalize))
    for style, convert in candidates:
        if convert(field_name) == encode_name:
            return style
    return None
457+
458+
459+
def get_msgspec_rename_config(schema_type: type) -> "Optional[str]":
    """Infer the msgspec rename configuration of a struct type.

    The rename style is inferred by comparing each field's ``name`` with its
    ``encode_name``. Fields are scanned in order and the first field whose
    difference matches a known style decides the result. A field whose
    encoded name differs but matches no known style (for example an explicit
    per-field name) is skipped instead of ending the search, so one such
    field does not hide a rename that is visible on the remaining fields.

    Args:
        schema_type: The msgspec struct type to inspect.

    Returns:
        The detected rename style ("camel", "kebab" or "pascal"), or None
        when msgspec is not installed, ``schema_type`` is not a msgspec
        struct, the struct has no fields, or no field reveals a known style.

    Examples:
        >>> class User(msgspec.Struct, rename="camel"):
        ...     user_id: int
        >>> get_msgspec_rename_config(User)
        'camel'

        >>> class Product(msgspec.Struct):
        ...     product_id: int
        >>> get_msgspec_rename_config(Product) is None
        True
    """
    if not MSGSPEC_INSTALLED:
        return None

    if not is_msgspec_struct(schema_type):
        return None

    # Imported lazily so this module can be imported without msgspec installed.
    from msgspec import structs

    for field in structs.fields(schema_type):  # type: ignore[arg-type]
        if field.name == field.encode_name:
            # Unchanged names carry no information about the rename style.
            continue
        pattern = _detect_rename_pattern(field.name, field.encode_name)
        if pattern is not None:
            return pattern

    # No field revealed a recognisable rename style.
    return None
504+
505+
432506
def is_attrs_instance(obj: Any) -> "TypeGuard[AttrsInstanceStub]":
433507
"""Check if a value is an attrs class instance.
434508

0 commit comments

Comments
 (0)