Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/provinzkraut/unasyncd
rev: "v0.9.0"
rev: "v0.10.0"
hooks:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.14.13"
rev: "v0.14.14"
hooks:
# Run the linter.
- id: ruff
Expand Down
98 changes: 96 additions & 2 deletions advanced_alchemy/_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# ruff: noqa: PLR0911
import datetime
import decimal
import enum
from typing import Any
import uuid
from typing import Any, ClassVar, Protocol, Union, cast

from typing_extensions import runtime_checkable

Expand All @@ -11,7 +14,6 @@

PYDANTIC_INSTALLED = True
except ImportError:
from typing import ClassVar, Protocol

@runtime_checkable
class BaseModel(Protocol): # type: ignore[no-redef]
Expand Down Expand Up @@ -90,3 +92,95 @@ def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
str: The ISO 8601 formatted date string.
"""
return dt.isoformat()


def encode_complex_type(obj: Any) -> Any:
    """Wrap a value that is not natively JSON serializable in a tagged marker dict.

    Supported values are returned as ``{"__type__": <tag>, "value": <payload>}``:

    - ``datetime``, ``date``, ``time``: ISO format string payload
    - ``timedelta``: total seconds as float
    - ``Decimal``: string representation
    - ``bytes``: hex string
    - ``UUID``: string representation
    - ``set``, ``frozenset``: list payload (tag ``"set"``); members that are
      themselves supported types are encoded recursively so their type
      survives a round trip through ``decode_complex_type``

    Args:
        obj: The object to encode.

    Returns:
        The marker dict for ``obj``, or ``None`` if the type is not supported.
    """
    # ``datetime`` is a subclass of ``date``, so it must be tested first.
    if isinstance(obj, datetime.datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, datetime.date):
        return {"__type__": "date", "value": obj.isoformat()}
    if isinstance(obj, datetime.time):
        return {"__type__": "time", "value": obj.isoformat()}
    if isinstance(obj, datetime.timedelta):
        return {"__type__": "timedelta", "value": obj.total_seconds()}
    if isinstance(obj, decimal.Decimal):
        return {"__type__": "decimal", "value": str(obj)}
    if isinstance(obj, bytes):
        return {"__type__": "bytes", "value": obj.hex()}
    if isinstance(obj, uuid.UUID):
        return {"__type__": "uuid", "value": str(obj)}
    if isinstance(obj, (set, frozenset)):
        members = cast("Union[set[Any], frozenset[Any]]", obj)  # type: ignore[redundant-cast]
        items: list[Any] = []
        for member in members:
            # Unsupported members (encoder returned None) are kept unchanged.
            encoded = encode_complex_type(member)
            items.append(member if encoded is None else encoded)
        return {"__type__": "set", "value": items}
    return None


def decode_complex_type(value: Any) -> Any:
    """Walk a decoded JSON structure and restore tagged special values.

    Lists are decoded element by element. Dicts have their keys coerced to
    ``str`` and their values decoded first; a dict that then holds both a
    ``"__type__"`` and a ``"value"`` key is handed to ``_decode_typed_marker``.
    Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        mapping = cast("dict[Any, Any]", value)  # type: ignore[redundant-cast]

        # Children are decoded before the dict itself is inspected.
        result: dict[str, Any] = {}
        for key, item in mapping.items():
            result[str(key)] = decode_complex_type(item)

        if "__type__" not in result or "value" not in result:
            return result
        return _decode_typed_marker(result)

    if isinstance(value, list):
        elements = cast("list[Any]", value)  # type: ignore[redundant-cast]
        return [decode_complex_type(element) for element in elements]

    return value


def _decode_typed_marker(obj: dict[str, Any]) -> Any:
    """Rebuild the Python value described by a ``__type__``/``value`` marker dict.

    Args:
        obj: A dict holding a ``"__type__"`` tag and a ``"value"`` payload.

    Returns:
        The reconstructed object, or ``obj`` itself when the tag is not recognised.
    """
    tag = obj["__type__"]
    payload = obj["value"]

    # Tag/constructor pairs, compared with ``==`` in this order.
    builders = (
        ("datetime", datetime.datetime.fromisoformat),
        ("date", datetime.date.fromisoformat),
        ("time", datetime.time.fromisoformat),
        ("timedelta", lambda seconds: datetime.timedelta(seconds=seconds)),
        ("decimal", decimal.Decimal),
        ("bytes", bytes.fromhex),
        ("uuid", uuid.UUID),
        ("set", set),
    )
    for name, build in builders:
        if tag == name:
            return build(payload)

    return obj
16 changes: 13 additions & 3 deletions advanced_alchemy/alembic/templates/asyncio/script.py.mako
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@ import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash, FernetBackend
from advanced_alchemy.types.encrypted_string import PGCryptoBackend
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
try:
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
except ImportError:
Argon2Hasher = Any # type: ignore
try:
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
except ImportError:
PasslibHasher = Any # type: ignore
try:
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
except ImportError:
PwdlibHasher = Any # type: ignore

if TYPE_CHECKING:
from collections.abc import Sequence

Expand Down
16 changes: 13 additions & 3 deletions advanced_alchemy/alembic/templates/sync/script.py.mako
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@ import sqlalchemy as sa
from alembic import op
from advanced_alchemy.types import EncryptedString, EncryptedText, GUID, ORA_JSONB, DateTimeUTC, StoredObject, PasswordHash, FernetBackend
from advanced_alchemy.types.encrypted_string import PGCryptoBackend
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
from sqlalchemy import Text # noqa: F401
${imports if imports else ""}
try:
from advanced_alchemy.types.password_hash.argon2 import Argon2Hasher
except ImportError:
Argon2Hasher = Any # type: ignore
try:
from advanced_alchemy.types.password_hash.passlib import PasslibHasher
except ImportError:
PasslibHasher = Any # type: ignore
try:
from advanced_alchemy.types.password_hash.pwdlib import PwdlibHasher
except ImportError:
PwdlibHasher = Any # type: ignore

if TYPE_CHECKING:
from collections.abc import Sequence

Expand Down
125 changes: 13 additions & 112 deletions advanced_alchemy/cache/serializers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Serialization utilities for caching SQLAlchemy models."""

from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Any, TypeVar, Union, cast
from uuid import UUID
from typing import Any, TypeVar

from sqlalchemy import inspect as sa_inspect

from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy._serialization import (
decode_complex_type,
decode_json,
encode_complex_type,
encode_json,
)

__all__ = (
"default_deserializer",
Expand All @@ -23,107 +25,6 @@
"""Metadata key for the table name in serialized data."""


def _json_encoder(obj: Any) -> Any:
    """Encode a SQLAlchemy model attribute value as a tagged marker dict.

    Supported values become ``{"__type__": <tag>, "value": <payload>}``:

    - datetime, date, time: ISO format string payload
    - timedelta: total seconds as float
    - Decimal: string representation
    - bytes: hex string
    - UUID: string representation
    - set, frozenset: list

    Args:
        obj: The object to encode.

    Returns:
        The marker dict describing ``obj``.

    Raises:
        TypeError: If the object type is not supported.
    """
    # Order matters: ``datetime`` subclasses ``date`` and must be matched first.
    encoders = (
        (datetime, "datetime", lambda v: v.isoformat()),
        (date, "date", lambda v: v.isoformat()),
        (time, "time", lambda v: v.isoformat()),
        (timedelta, "timedelta", lambda v: v.total_seconds()),
        (Decimal, "decimal", str),
        (bytes, "bytes", lambda v: v.hex()),
        (UUID, "uuid", str),
        ((set, frozenset), "set", list),
    )
    for kinds, tag, convert in encoders:
        if isinstance(obj, kinds):
            return {"__type__": tag, "value": convert(obj)}

    msg = f"Object of type {type(obj).__name__} is not JSON serializable"
    raise TypeError(msg)


def _json_decoder(obj: dict[str, Any]) -> Any:
    """Turn a tagged marker dict back into the Python value it describes.

    Args:
        obj: The dictionary to decode.

    Returns:
        The decoded object, or ``obj`` unchanged when it carries no
        ``"__type__"`` key or the tag is not recognised.
    """
    if "__type__" not in obj:
        return obj

    tag = obj["__type__"]
    payload = obj["value"]

    # Tag/constructor pairs, compared with ``==`` in this order.
    builders = (
        ("datetime", datetime.fromisoformat),
        ("date", date.fromisoformat),
        ("time", time.fromisoformat),
        ("timedelta", lambda seconds: timedelta(seconds=seconds)),
        ("decimal", Decimal),
        ("bytes", bytes.fromhex),
        ("uuid", UUID),
        ("set", set),
    )
    for name, build in builders:
        if tag == name:
            return build(payload)

    return obj


def _decode_special_types(value: Any) -> Any:
    """Walk a decoded JSON structure and restore tagged special values.

    ``encode_json`` (msgspec/orjson/json fallback) offers no equivalent of
    stdlib json's ``object_hook`` callback, so the
    ``{"__type__": ..., "value": ...}`` structures produced by
    ``_json_encoder`` are resolved here after parsing.
    """
    if isinstance(value, dict):
        mapping = cast("dict[Any, Any]", value)  # type: ignore[redundant-cast]

        # Children are decoded before the dict itself is inspected.
        result: dict[str, Any] = {}
        for key, item in mapping.items():
            result[str(key)] = _decode_special_types(item)

        if "__type__" not in result or "value" not in result:
            return result
        return _json_decoder(result)

    if isinstance(value, list):
        elements = cast("list[Any]", value)  # type: ignore[redundant-cast]
        return [_decode_special_types(element) for element in elements]

    return value


def default_serializer(model: Any) -> bytes:
"""Serialize a SQLAlchemy model instance to JSON bytes.

Expand Down Expand Up @@ -160,11 +61,11 @@ def default_serializer(model: Any) -> bytes:
if getattr(column, "_insert_sentinel", False):
continue
value = getattr(model, column.key)
try:
# Encode special types into JSON-friendly marker structures.
data[column.key] = _json_encoder(value)
except TypeError:
# Leave unknown types alone; encode_json has its own hooks/fallbacks.

# Encode special types into JSON-friendly marker structures.
if (encoded := encode_complex_type(value)) is not None:
data[column.key] = encoded
else:
data[column.key] = value

return encode_json(data).encode("utf-8")
Expand Down Expand Up @@ -201,7 +102,7 @@ def default_deserializer(data: bytes, model_class: type[T]) -> T:
# user is a detached User instance
"""
parsed_raw = decode_json(data)
parsed = _decode_special_types(parsed_raw)
parsed = decode_complex_type(parsed_raw)

# Validate model class matches
serialized_model = parsed.pop(_MODEL_KEY, None)
Expand Down
Loading