Skip to content

Commit 8a63277

Browse files
authored
feat: add optional numpy serialization support (#104)
Implements optional NumPy array serialization support in SQLSpec's serialization system
1 parent 9e07857 commit 8a63277

File tree

4 files changed

+422
-9
lines changed

4 files changed

+422
-9
lines changed

sqlspec/_serialization.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
33
Provides a Protocol-based serialization system that users can extend.
44
Supports msgspec, orjson, and standard library JSON with automatic fallback.
5+
6+
Features optional numpy array serialization when numpy is installed.
7+
Arrays are automatically converted to lists during JSON encoding.
58
"""
69

710
import contextlib
@@ -11,17 +14,23 @@
1114
from abc import ABC, abstractmethod
1215
from typing import Any, Final, Literal, Protocol, overload
1316

17+
from sqlspec._typing import NUMPY_INSTALLED
1418
from sqlspec.typing import MSGSPEC_INSTALLED, ORJSON_INSTALLED, PYDANTIC_INSTALLED, BaseModel
1519

1620

17-
def _type_to_string(value: Any) -> str: # pragma: no cover
21+
def _type_to_string(value: Any) -> Any: # pragma: no cover
1822
"""Convert special types to strings for JSON serialization.
1923
24+
Handles datetime, date, enums, Pydantic models, and numpy arrays.
25+
2026
Args:
2127
value: Value to convert.
2228
2329
Returns:
24-
String representation of the value.
30+
Serializable representation of the value (string, list, dict, etc.).
31+
32+
Raises:
33+
TypeError: If value cannot be serialized.
2534
"""
2635
if isinstance(value, datetime.datetime):
2736
return convert_datetime_to_gmt_iso(value)
@@ -31,10 +40,16 @@ def _type_to_string(value: Any) -> str: # pragma: no cover
3140
return str(value.value)
3241
if PYDANTIC_INSTALLED and isinstance(value, BaseModel):
3342
return value.model_dump_json()
43+
if NUMPY_INSTALLED:
44+
import numpy as np
45+
46+
if isinstance(value, np.ndarray):
47+
return value.tolist()
3448
try:
3549
return str(value)
3650
except Exception as exc:
37-
raise TypeError from exc
51+
msg = f"Cannot serialize {type(value).__name__}"
52+
raise TypeError(msg) from exc
3853

3954

4055
class JSONSerializer(Protocol):
@@ -128,22 +143,37 @@ def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any:
128143

129144

130145
class OrjsonSerializer(BaseJSONSerializer):
131-
"""Orjson-based JSON serializer with native datetime/UUID support."""
146+
"""Orjson-based JSON serializer with native datetime/UUID support.
147+
148+
Automatically enables numpy serialization if numpy is installed.
149+
"""
132150

133151
__slots__ = ()
134152

135153
def encode(self, data: Any, *, as_bytes: bool = False) -> str | bytes:
136-
"""Encode data using orjson."""
154+
"""Encode data using orjson.
155+
156+
Args:
157+
data: Data to encode.
158+
as_bytes: Whether to return bytes instead of string.
159+
160+
Returns:
161+
JSON string or bytes depending on as_bytes parameter.
162+
"""
137163
from orjson import (
138164
OPT_NAIVE_UTC, # pyright: ignore[reportUnknownVariableType]
139-
OPT_SERIALIZE_NUMPY, # pyright: ignore[reportUnknownVariableType]
140165
OPT_SERIALIZE_UUID, # pyright: ignore[reportUnknownVariableType]
141166
)
142167
from orjson import dumps as _orjson_dumps # pyright: ignore[reportMissingImports]
143168

144-
result = _orjson_dumps(
145-
data, default=_type_to_string, option=OPT_SERIALIZE_NUMPY | OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
146-
)
169+
options = OPT_NAIVE_UTC | OPT_SERIALIZE_UUID
170+
171+
if NUMPY_INSTALLED:
172+
from orjson import OPT_SERIALIZE_NUMPY # pyright: ignore[reportUnknownVariableType]
173+
174+
options |= OPT_SERIALIZE_NUMPY
175+
176+
result = _orjson_dumps(data, default=_type_to_string, option=options)
147177
return result if as_bytes else result.decode("utf-8")
148178

149179
def decode(self, data: str | bytes, *, decode_bytes: bool = True) -> Any:

sqlspec/_typing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,24 @@ def slice(self, offset: int = 0, length: "int | None" = None) -> Any:
389389
PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
390390

391391

392+
@runtime_checkable
393+
class NumpyArrayStub(Protocol):
394+
"""Protocol stub for numpy.ndarray when numpy is not installed.
395+
396+
Provides minimal interface for type checking and serialization support.
397+
"""
398+
399+
def tolist(self) -> "list[Any]":
400+
"""Convert array to Python list."""
401+
...
402+
403+
404+
try:
405+
from numpy import ndarray as NumpyArray # noqa: N812
406+
except ImportError:
407+
NumpyArray = NumpyArrayStub # type: ignore[assignment,misc]
408+
409+
392410
try:
393411
from opentelemetry import trace # pyright: ignore[reportMissingImports, reportAssignmentType]
394412
from opentelemetry.trace import ( # pyright: ignore[reportMissingImports, reportAssignmentType]
@@ -660,6 +678,8 @@ async def insert_returning(self, conn: Any, query_name: str, sql: str, parameter
660678
"FailFastStub",
661679
"Gauge",
662680
"Histogram",
681+
"NumpyArray",
682+
"NumpyArrayStub",
663683
"Span",
664684
"Status",
665685
"StatusCode",

sqlspec/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
FailFast,
4242
Gauge,
4343
Histogram,
44+
NumpyArray,
4445
Span,
4546
Status,
4647
StatusCode,
@@ -233,6 +234,7 @@ class StorageMixin(MixinOf(DriverProtocol)): ...
233234
"ModelDictList",
234235
"ModelDictList",
235236
"ModelT",
237+
"NumpyArray",
236238
"PoolT",
237239
"PoolT_co",
238240
"PydanticOrMsgspecT",

0 commit comments

Comments
 (0)