Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ benchmarks/
*.duckdb
.crush
CRUSH.md
*.md
!README.md
!CONTRIBUTING.md
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ include = [
"sqlspec/core/**/*.py", # Core module
"sqlspec/loader.py", # Loader module

# === ADAPTER TYPE CONVERTERS ===
"sqlspec/adapters/adbc/type_converter.py", # ADBC type converter
"sqlspec/adapters/bigquery/type_converter.py", # BigQuery type converter
"sqlspec/adapters/duckdb/type_converter.py", # DuckDB type converter
"sqlspec/adapters/oracledb/type_converter.py", # Oracle type converter
"sqlspec/adapters/psqlpy/type_converter.py", # Psqlpy type converter

# === UTILITY MODULES ===
"sqlspec/utils/text.py", # Text utilities
"sqlspec/utils/sync_tools.py", # Synchronous utility functions
Expand Down
244 changes: 223 additions & 21 deletions sqlspec/_serialization.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
"""Enhanced serialization module with byte-aware encoding and class-based architecture.

Provides a Protocol-based serialization system that users can extend.
Supports msgspec, orjson, and standard library JSON with automatic fallback.
"""

import contextlib
import datetime
import enum
from typing import Any
import json
from abc import ABC, abstractmethod
from typing import Any, Final, Literal, Optional, Protocol, Union, overload

from sqlspec.typing import PYDANTIC_INSTALLED, BaseModel
from sqlspec.typing import MSGSPEC_INSTALLED, ORJSON_INSTALLED, PYDANTIC_INSTALLED, BaseModel


def _type_to_string(value: Any) -> str: # pragma: no cover
"""Convert special types to strings for JSON serialization.

Args:
value: Value to convert.

Returns:
String representation of the value.
"""
if isinstance(value, datetime.datetime):
return convert_datetime_to_gmt_iso(value)
if isinstance(value, datetime.date):
Expand All @@ -20,35 +37,206 @@ def _type_to_string(value: Any) -> str: # pragma: no cover
raise TypeError from exc


try:
from msgspec.json import Decoder, Encoder
class JSONSerializer(Protocol):
"""Protocol for JSON serialization implementations.

encoder, decoder = Encoder(enc_hook=_type_to_string), Decoder()
decode_json = decoder.decode
Users can implement this protocol to create custom serializers.
"""

def encode_json(data: Any) -> str: # pragma: no cover
return encoder.encode(data).decode("utf-8")
def encode(self, data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode data to JSON.

except ImportError:
try:
from orjson import ( # pyright: ignore[reportMissingImports]
Args:
data: Data to encode.
as_bytes: Whether to return bytes instead of string.

Returns:
JSON string or bytes depending on as_bytes parameter.
"""
...

def decode(self, data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode from JSON.

Args:
data: JSON string or bytes to decode.
decode_bytes: Whether to decode bytes input.

Returns:
Decoded Python object.
"""
...


class BaseJSONSerializer(ABC):
"""Base class for JSON serializers with common functionality."""

__slots__ = ()

@abstractmethod
def encode(self, data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode data to JSON."""
...

@abstractmethod
def decode(self, data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode from JSON."""
...


class MsgspecSerializer(BaseJSONSerializer):
"""Msgspec-based JSON serializer for optimal performance."""

__slots__ = ("_decoder", "_encoder")

def __init__(self) -> None:
"""Initialize msgspec encoder and decoder."""
from msgspec.json import Decoder, Encoder

self._encoder: Final[Encoder] = Encoder(enc_hook=_type_to_string)
self._decoder: Final[Decoder] = Decoder()

def encode(self, data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode data using msgspec."""
try:
if as_bytes:
return self._encoder.encode(data)
return self._encoder.encode(data).decode("utf-8")
except (TypeError, ValueError):
if ORJSON_INSTALLED:
return OrjsonSerializer().encode(data, as_bytes=as_bytes)
return StandardLibSerializer().encode(data, as_bytes=as_bytes)

def decode(self, data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode data using msgspec."""
if isinstance(data, bytes):
if decode_bytes:
try:
return self._decoder.decode(data)
except (TypeError, ValueError):
if ORJSON_INSTALLED:
return OrjsonSerializer().decode(data, decode_bytes=decode_bytes)
return StandardLibSerializer().decode(data, decode_bytes=decode_bytes)
return data

try:
return self._decoder.decode(data.encode("utf-8"))
except (TypeError, ValueError):
if ORJSON_INSTALLED:
return OrjsonSerializer().decode(data, decode_bytes=decode_bytes)
return StandardLibSerializer().decode(data, decode_bytes=decode_bytes)


class OrjsonSerializer(BaseJSONSerializer):
"""Orjson-based JSON serializer with native datetime/UUID support."""

__slots__ = ()

def encode(self, data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode data using orjson."""
from orjson import (
OPT_NAIVE_UTC, # pyright: ignore[reportUnknownVariableType]
OPT_SERIALIZE_NUMPY, # pyright: ignore[reportUnknownVariableType]
OPT_SERIALIZE_UUID, # pyright: ignore[reportUnknownVariableType]
)
from orjson import dumps as _encode_json # pyright: ignore[reportUnknownVariableType,reportMissingImports]
from orjson import loads as decode_json # type: ignore[no-redef,assignment,unused-ignore]
from orjson import dumps as _orjson_dumps # pyright: ignore[reportMissingImports]

result = _orjson_dumps(
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
)
return result if as_bytes else result.decode("utf-8")

def decode(self, data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode data using orjson."""
from orjson import loads as _orjson_loads # pyright: ignore[reportMissingImports]

if isinstance(data, bytes):
if decode_bytes:
return _orjson_loads(data)
return data
return _orjson_loads(data)

def encode_json(data: Any) -> str: # pragma: no cover
return _encode_json(
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
).decode("utf-8")

except ImportError:
from json import dumps as encode_json # type: ignore[assignment]
from json import loads as decode_json # type: ignore[assignment]
class StandardLibSerializer(BaseJSONSerializer):
"""Standard library JSON serializer as fallback."""

__all__ = ("convert_date_to_iso", "convert_datetime_to_gmt_iso", "decode_json", "encode_json")
__slots__ = ()

def encode(self, data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode data using standard library json."""
json_str = json.dumps(data, default=_type_to_string)
return json_str.encode("utf-8") if as_bytes else json_str

def decode(self, data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode data using standard library json."""
if isinstance(data, bytes):
if decode_bytes:
return json.loads(data.decode("utf-8"))
return data
return json.loads(data)


_default_serializer: Optional[JSONSerializer] = None


def get_default_serializer() -> JSONSerializer:
"""Get the default serializer based on available libraries.

Priority: msgspec > orjson > stdlib

Returns:
The best available JSON serializer.
"""
global _default_serializer

if _default_serializer is None:
if MSGSPEC_INSTALLED:
with contextlib.suppress(ImportError):
_default_serializer = MsgspecSerializer()

if _default_serializer is None and ORJSON_INSTALLED:
with contextlib.suppress(ImportError):
_default_serializer = OrjsonSerializer()

if _default_serializer is None:
_default_serializer = StandardLibSerializer()

assert _default_serializer is not None
return _default_serializer


@overload
def encode_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... # pragma: no cover


@overload
def encode_json(data: Any, *, as_bytes: Literal[True]) -> bytes: ... # pragma: no cover


def encode_json(data: Any, *, as_bytes: bool = False) -> Union[str, bytes]:
"""Encode to JSON, optionally returning bytes for optimal performance.

Args:
data: The data to encode.
as_bytes: Whether to return bytes instead of string.

Returns:
JSON string or bytes depending on as_bytes parameter.
"""
return get_default_serializer().encode(data, as_bytes=as_bytes)


def decode_json(data: Union[str, bytes], *, decode_bytes: bool = True) -> Any:
"""Decode from JSON string or bytes efficiently.

Args:
data: JSON string or bytes to decode.
decode_bytes: Whether to decode bytes input.

Returns:
Decoded Python object.
"""
return get_default_serializer().decode(data, decode_bytes=decode_bytes)


def convert_datetime_to_gmt_iso(dt: datetime.datetime) -> str: # pragma: no cover
Expand All @@ -75,3 +263,17 @@ def convert_date_to_iso(dt: datetime.date) -> str: # pragma: no cover
The ISO formatted date string.
"""
return dt.isoformat()


__all__ = (
"BaseJSONSerializer",
"JSONSerializer",
"MsgspecSerializer",
"OrjsonSerializer",
"StandardLibSerializer",
"convert_date_to_iso",
"convert_datetime_to_gmt_iso",
"decode_json",
"encode_json",
"get_default_serializer",
)
9 changes: 9 additions & 0 deletions sqlspec/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ class UnsetTypeStub(enum.Enum):
MSGSPEC_INSTALLED = False # pyright: ignore[reportConstantRedefinition]


try:
import orjson # noqa: F401

ORJSON_INSTALLED = True # pyright: ignore[reportConstantRedefinition]
except ImportError:
ORJSON_INSTALLED = False # pyright: ignore[reportConstantRedefinition]


# Always define stub type for DTOData
@runtime_checkable
class DTODataStub(Protocol[T]):
Expand Down Expand Up @@ -621,6 +629,7 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
"NUMPY_INSTALLED",
"OBSTORE_INSTALLED",
"OPENTELEMETRY_INSTALLED",
"ORJSON_INSTALLED",
"PGVECTOR_INSTALLED",
"PROMETHEUS_INSTALLED",
"PYARROW_INSTALLED",
Expand Down
Loading