Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,11 @@ include = [
"sqlspec/loader.py", # Loader module

# === UTILITY MODULES ===
"sqlspec/utils/text.py", # Text utilities
"sqlspec/utils/sync_tools.py", # Synchronous utility functions
"sqlspec/utils/type_guards.py", # Type guard utilities
"sqlspec/utils/fixtures.py", # File fixture loading
"sqlspec/utils/text.py", # Text utilities
"sqlspec/utils/sync_tools.py", # Synchronous utility functions
"sqlspec/utils/type_guards.py", # Type guard utilities
"sqlspec/utils/fixtures.py", # File fixture loading
"sqlspec/utils/data_transformation.py", # Data transformation utilities
]
mypy-args = [
"--ignore-missing-imports",
Expand Down
47 changes: 41 additions & 6 deletions sqlspec/driver/mixins/_result_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: C901
"""Result handling and schema conversion mixins for database drivers."""

import datetime
Expand All @@ -22,7 +23,16 @@
convert,
get_type_adapter,
)
from sqlspec.utils.type_guards import is_attrs_schema, is_dataclass, is_msgspec_struct, is_pydantic_model
from sqlspec.utils.data_transformation import transform_dict_keys
from sqlspec.utils.text import camelize, kebabize, pascalize
from sqlspec.utils.type_guards import (
get_msgspec_rename_config,
is_attrs_schema,
is_dataclass,
is_dict,
is_msgspec_struct,
is_pydantic_model,
)

__all__ = ("_DEFAULT_TYPE_DECODERS", "_default_msgspec_deserializer")

Expand Down Expand Up @@ -143,21 +153,46 @@ def to_schema(data: Any, *, schema_type: "Optional[type[ModelDTOT]]" = None) ->
if isinstance(data, list):
result: list[Any] = []
for item in data:
if hasattr(item, "keys"):
if is_dict(item):
result.append(schema_type(**dict(item))) # type: ignore[operator]
else:
result.append(item)
return result
if hasattr(data, "keys"):
if is_dict(data):
return schema_type(**dict(data)) # type: ignore[operator]
if isinstance(data, dict):
return schema_type(**data) # type: ignore[operator]
return data
if is_msgspec_struct(schema_type):
rename_config = get_msgspec_rename_config(schema_type) # type: ignore[arg-type]
deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
if not isinstance(data, Sequence):
return convert(obj=data, type=schema_type, from_attributes=True, dec_hook=deserializer)
return convert(obj=data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]

# Transform field names if rename configuration exists
transformed_data = data
if (rename_config and is_dict(data)) or (isinstance(data, Sequence) and data and is_dict(data[0])):
try:
converter = None
if rename_config == "camel":
converter = camelize
elif rename_config == "kebab":
converter = kebabize
elif rename_config == "pascal":
converter = pascalize

if converter is not None:
if isinstance(data, Sequence):
transformed_data = [
transform_dict_keys(item, converter) if is_dict(item) else item for item in data
]
else:
transformed_data = transform_dict_keys(data, converter) if is_dict(data) else data
except Exception as e:
logger.debug("Field name transformation failed for msgspec schema: %s", e)
transformed_data = data

if not isinstance(transformed_data, Sequence):
return convert(obj=transformed_data, type=schema_type, from_attributes=True, dec_hook=deserializer)
return convert(obj=transformed_data, type=list[schema_type], from_attributes=True, dec_hook=deserializer) # type: ignore[valid-type]
if is_pydantic_model(schema_type):
if not isinstance(data, Sequence):
adapter = get_type_adapter(schema_type)
Expand Down
120 changes: 120 additions & 0 deletions sqlspec/utils/data_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Data transformation utilities for SQLSpec.

Provides functions for transforming data structures, particularly for
field name conversion when mapping database results to schema objects.
Used primarily for msgspec field name conversion with rename configurations.
"""

from typing import Any, Callable, Union

__all__ = ("transform_dict_keys",)


def _safe_convert_key(key: Any, converter: Callable[[str], str]) -> Any:
"""Safely convert a key using the converter function.

Args:
key: Key to convert (may not be a string).
converter: Function to convert string keys.

Returns:
Converted key if conversion succeeds, original key otherwise.
"""
if not isinstance(key, str):
return key

try:
return converter(key)
except (TypeError, ValueError, AttributeError):
# If conversion fails, return the original key
return key


def transform_dict_keys(data: Union[dict, list, Any], converter: Callable[[str], str]) -> Union[dict, list, Any]:
    """Rename dictionary keys throughout a data structure.

    Dispatches on the type of ``data``: dictionaries go to
    ``_transform_dict``, lists go to ``_transform_list``, and every other
    value is handed back as-is. Those helpers call back into this function,
    so nested dictionaries and lists are processed recursively.

    Args:
        data: Value to process; a dict, a list, or anything else.
        converter: Callable that maps one string key to another
            (e.g., camelize, kebabize).

    Returns:
        A new dict or list with converted keys when ``data`` is a dict or
        list; otherwise ``data`` itself, unchanged.

    Examples:
        Flat mapping, snake_case to camelCase:

        >>> from sqlspec.utils.text import camelize
        >>> transform_dict_keys({"user_id": 123, "created_at": "2024-01-01"}, camelize)
        {'userId': 123, 'createdAt': '2024-01-01'}

        Nested mapping containing a list of dicts:

        >>> transform_dict_keys({"order_items": [{"item_id": 1}, {"item_id": 2}]}, camelize)
        {'orderItems': [{'itemId': 1}, {'itemId': 2}]}
    """
    if isinstance(data, dict):
        handler = _transform_dict
    elif isinstance(data, list):
        handler = _transform_list
    else:
        # Scalars, tuples, sets, etc. are not traversed.
        return data
    return handler(data, converter)


def _transform_dict(data: dict, converter: Callable[[str], str]) -> dict:
"""Transform a dictionary's keys recursively.

Args:
data: Dictionary to transform.
converter: Function to convert string keys.

Returns:
Dictionary with transformed keys and recursively transformed values.
"""
transformed = {}

for key, value in data.items():
# Convert the key using the provided converter function
# Use safe conversion that handles edge cases without try-except
converted_key = _safe_convert_key(key, converter)

# Recursively transform the value
transformed_value = transform_dict_keys(value, converter)

transformed[converted_key] = transformed_value

return transformed


def _transform_list(data: list, converter: Callable[[str], str]) -> list:
"""Transform a list's elements recursively.

Args:
data: List to transform.
converter: Function to convert string keys in nested structures.

Returns:
List with recursively transformed elements.
"""
# Use list comprehension for better performance and avoid try-except in loop
return [transform_dict_keys(item, converter) for item in data]
46 changes: 27 additions & 19 deletions sqlspec/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,7 @@
_SNAKE_CASE_REMOVE_NON_WORD = re.compile(r"[^\w]+", re.UNICODE)
_SNAKE_CASE_MULTIPLE_UNDERSCORES = re.compile(r"__+", re.UNICODE)

__all__ = ("camelize", "check_email", "slugify", "snake_case")


def check_email(email: str) -> str:
    """Run a minimal sanity check on an email address and normalize it.

    Args:
        email: Address to check.

    Raises:
        ValueError: If the address contains no ``@`` character.

    Returns:
        The address converted to lowercase.
    """
    if "@" in email:
        return email.lower()
    msg = "Invalid email!"
    raise ValueError(msg)
__all__ = ("camelize", "kebabize", "pascalize", "slugify", "snake_case")


def slugify(value: str, allow_unicode: bool = False, separator: Optional[str] = None) -> str:
Expand Down Expand Up @@ -80,6 +62,32 @@ def camelize(string: str) -> str:
return "".join(word if index == 0 else word.capitalize() for index, word in enumerate(string.split("_")))


@lru_cache(maxsize=100)
def kebabize(string: str) -> str:
    """Convert an underscore-separated string to kebab-case.

    The input is split on ``_``, empty segments (from leading, trailing or
    repeated underscores) are dropped, each remaining segment is lowercased,
    and the segments are joined with ``-``.

    Args:
        string: The string to convert.

    Returns:
        The kebab-case version of the string.
    """
    segments = filter(None, string.split("_"))
    return "-".join(map(str.lower, segments))


@lru_cache(maxsize=100)
def pascalize(string: str) -> str:
    """Convert an underscore-separated string to PascalCase.

    The input is split on ``_``, empty segments are dropped, and each
    remaining segment is passed through ``str.capitalize`` (first character
    uppercased, the rest lowercased) before the segments are concatenated.

    Args:
        string: The string to convert.

    Returns:
        The PascalCase version of the string.
    """
    segments = filter(None, string.split("_"))
    return "".join(map(str.capitalize, segments))


@lru_cache(maxsize=100)
def snake_case(string: str) -> str:
"""Convert a string to snake_case.
Expand Down
74 changes: 74 additions & 0 deletions sqlspec/utils/type_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from collections.abc import Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from sqlspec.typing import (
Expand Down Expand Up @@ -59,6 +60,7 @@
"extract_dataclass_items",
"get_initial_expression",
"get_literal_parent",
"get_msgspec_rename_config",
"get_node_expressions",
"get_node_this",
"get_param_style_and_name",
Expand Down Expand Up @@ -429,6 +431,78 @@ def is_msgspec_struct_without_field(obj: Any, field_name: str) -> "TypeGuard[Str
return False


@lru_cache(maxsize=500)
def _detect_rename_pattern(field_name: str, encode_name: str) -> "Optional[str]":
    """Work out which naming convention turns ``field_name`` into ``encode_name``.

    Args:
        field_name: Original field name (e.g., "user_id")
        encode_name: Encoded field name (e.g., "userId")

    Returns:
        "camel", "kebab" or "pascal" for the first convention whose
        conversion of ``field_name`` equals ``encode_name``. None when no
        convention matches or when the two names are identical.
    """
    from sqlspec.utils.text import camelize, kebabize, pascalize

    # Order matters: conventions are tried camel, then kebab, then pascal,
    # and the first match wins.
    candidates = (("camel", camelize), ("kebab", kebabize), ("pascal", pascalize))
    for label, convert_name in candidates:
        if encode_name == convert_name(field_name) and encode_name != field_name:
            return label
    return None


def get_msgspec_rename_config(schema_type: type) -> "Optional[str]":
    """Infer the msgspec rename convention used by a struct type.

    The original ``rename`` argument is not read directly. Instead, each
    field's ``name`` is compared with its ``encode_name`` and the convention
    is inferred via ``_detect_rename_pattern``.

    Fields are scanned in the order returned by ``msgspec.structs.fields``.
    A renamed field that matches none of the known conventions (presumably
    an explicit per-field name override — verify against msgspec) does not
    end the search: later fields are still examined, and the first
    recognizable convention is returned.

    Args:
        schema_type: The msgspec struct type to inspect.

    Returns:
        "camel", "kebab" or "pascal" when a convention is detected. None when
        msgspec is not installed, ``schema_type`` is not a msgspec struct,
        the struct has no fields, no field is renamed, or no renamed field
        matches a known convention.

    Examples:
        >>> class User(msgspec.Struct, rename="camel"):
        ...     user_id: int
        >>> get_msgspec_rename_config(User)
        'camel'

        >>> class Product(msgspec.Struct):
        ...     product_id: int
        >>> get_msgspec_rename_config(Product) is None
        True
    """
    if not MSGSPEC_INSTALLED or not is_msgspec_struct(schema_type):
        return None

    from msgspec import structs

    for field in structs.fields(schema_type):  # type: ignore[arg-type]
        if field.name == field.encode_name:
            # Not renamed; this field says nothing about the convention.
            continue
        pattern = _detect_rename_pattern(field.name, field.encode_name)
        if pattern is not None:
            return pattern
        # Renamed but unrecognized: keep scanning rather than giving up on
        # the first such field.

    # No field is renamed, or no renamed field matched a known convention.
    return None


def is_attrs_instance(obj: Any) -> "TypeGuard[AttrsInstanceStub]":
"""Check if a value is an attrs class instance.

Expand Down
Loading