Skip to content

Commit 356ba59

Browse files
authored
feat: add NumPy array serialization support in SQLSpec plugin (#165)
Adds automatic NumPy array serialization to SQLSpec's Litestar plugin, enabling seamless bidirectional conversion between NumPy arrays and JSON for vector embedding workflows.
1 parent df5ee30 commit 356ba59

File tree

6 files changed

+593
-25
lines changed

6 files changed

+593
-25
lines changed

CLAUDE.md

Lines changed: 0 additions & 16 deletions
This file was deleted.

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
AGENTS.md

GEMINI.md

Lines changed: 0 additions & 5 deletions
This file was deleted.

GEMINI.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
AGENTS.md

sqlspec/extensions/litestar/plugin.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
pool_provider_maker,
2626
session_provider_maker,
2727
)
28-
from sqlspec.typing import ConnectionT, PoolT
28+
from sqlspec.typing import NUMPY_INSTALLED, ConnectionT, PoolT, SchemaT
2929
from sqlspec.utils.logging import get_logger
30+
from sqlspec.utils.serializers import numpy_array_dec_hook, numpy_array_enc_hook, numpy_array_predicate
3031

3132
if TYPE_CHECKING:
3233
from collections.abc import AsyncGenerator, Callable
@@ -82,6 +83,10 @@ class _PluginConfigState:
8283
class SQLSpecPlugin(InitPluginProtocol, CLIPlugin):
8384
"""Litestar plugin for SQLSpec database integration.
8485
86+
Automatically configures NumPy array serialization when NumPy is installed,
87+
enabling seamless bidirectional conversion between NumPy arrays and JSON
88+
for vector embedding workflows.
89+
8590
Session Table Migrations:
8691
The Litestar extension includes migrations for creating session storage tables.
8792
To include these migrations in your database migration workflow, add 'litestar'
@@ -225,6 +230,8 @@ def on_cli_init(self, cli: "Group") -> None:
225230
def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
226231
"""Configure Litestar application with SQLSpec database integration.
227232
233+
Automatically registers NumPy array serialization when NumPy is installed.
234+
228235
Args:
229236
app_config: The Litestar application configuration instance.
230237
@@ -239,7 +246,7 @@ def store_sqlspec_in_state() -> None:
239246
app_config.on_startup.append(store_sqlspec_in_state)
240247
app_config.signature_types.extend([SQLSpec, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT])
241248

242-
signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT}
249+
signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT, "SchemaT": SchemaT}
243250

244251
for state in self._plugin_configs:
245252
state.annotation = type(state.config)
@@ -262,6 +269,23 @@ def store_sqlspec_in_state() -> None:
262269
if signature_namespace:
263270
app_config.signature_namespace.update(signature_namespace)
264271

272+
if NUMPY_INSTALLED:
273+
import numpy as np
274+
275+
if app_config.type_encoders is None:
276+
app_config.type_encoders = {np.ndarray: numpy_array_enc_hook}
277+
else:
278+
encoders_dict = dict(app_config.type_encoders)
279+
encoders_dict[np.ndarray] = numpy_array_enc_hook
280+
app_config.type_encoders = encoders_dict
281+
282+
if app_config.type_decoders is None:
283+
app_config.type_decoders = [(numpy_array_predicate, numpy_array_dec_hook)] # type: ignore[list-item]
284+
else:
285+
decoders_list = list(app_config.type_decoders)
286+
decoders_list.append((numpy_array_predicate, numpy_array_dec_hook)) # type: ignore[arg-type]
287+
app_config.type_decoders = decoders_list
288+
265289
return app_config
266290

267291
def get_annotations(

sqlspec/utils/serializers.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
33
Re-exports common JSON encoding and decoding functions from the core
44
serialization module for convenient access.
5+
6+
Provides NumPy array serialization hooks for framework integrations
7+
that support custom type encoders and decoders (e.g., Litestar).
58
"""
69

710
from typing import Any, Literal, overload
811

912
from sqlspec._serialization import decode_json, encode_json
13+
from sqlspec.typing import NUMPY_INSTALLED
1014

1115

1216
@overload
@@ -55,4 +59,109 @@ def from_json(data: str | bytes, *, decode_bytes: bool = True) -> Any:
5559
return decode_json(data)
5660

5761

58-
__all__ = ("from_json", "to_json")
62+
def numpy_array_enc_hook(value: Any) -> Any:
63+
"""Encode NumPy array to JSON-compatible list.
64+
65+
Converts NumPy ndarrays to Python lists for JSON serialization.
66+
Gracefully handles cases where NumPy is not installed by returning
67+
the original value unchanged.
68+
69+
Args:
70+
value: Value to encode (checked for ndarray type).
71+
72+
Returns:
73+
List representation if value is ndarray, original value otherwise.
74+
75+
Example:
76+
>>> import numpy as np
77+
>>> arr = np.array([1.0, 2.0, 3.0])
78+
>>> numpy_array_enc_hook(arr)
79+
[1.0, 2.0, 3.0]
80+
81+
>>> # Multi-dimensional arrays work automatically
82+
>>> arr_2d = np.array([[1, 2], [3, 4]])
83+
>>> numpy_array_enc_hook(arr_2d)
84+
[[1, 2], [3, 4]]
85+
"""
86+
if not NUMPY_INSTALLED:
87+
return value
88+
89+
import numpy as np
90+
91+
if isinstance(value, np.ndarray):
92+
return value.tolist()
93+
return value
94+
95+
96+
def numpy_array_dec_hook(value: Any) -> "Any":
97+
"""Decode list to NumPy array.
98+
99+
Converts Python lists to NumPy arrays when appropriate.
100+
Works best with typed schemas (Pydantic, msgspec) that expect ndarray.
101+
102+
Args:
103+
value: List to potentially convert to ndarray.
104+
105+
Returns:
106+
NumPy array if conversion successful, original value otherwise.
107+
108+
Note:
109+
Dtype is inferred by NumPy and may differ from original array.
110+
For explicit dtype control, construct arrays manually in application code.
111+
112+
Example:
113+
>>> numpy_array_dec_hook([1.0, 2.0, 3.0])
114+
array([1., 2., 3.])
115+
116+
>>> # Returns original value if NumPy not installed
117+
>>> # (when NUMPY_INSTALLED is False)
118+
>>> numpy_array_dec_hook([1, 2, 3])
119+
[1, 2, 3]
120+
"""
121+
if not NUMPY_INSTALLED:
122+
return value
123+
124+
import numpy as np
125+
126+
if isinstance(value, list):
127+
try:
128+
return np.array(value)
129+
except Exception:
130+
return value
131+
return value
132+
133+
134+
def numpy_array_predicate(value: Any) -> bool:
135+
"""Check if value is NumPy array instance.
136+
137+
Type checker for decoder registration in framework plugins.
138+
Returns False when NumPy is not installed.
139+
140+
Args:
141+
value: Value to type-check.
142+
143+
Returns:
144+
True if value is ndarray, False otherwise.
145+
146+
Example:
147+
>>> import numpy as np
148+
>>> numpy_array_predicate(np.array([1, 2, 3]))
149+
True
150+
151+
>>> numpy_array_predicate([1, 2, 3])
152+
False
153+
154+
>>> # Returns False when NumPy not installed
155+
>>> # (when NUMPY_INSTALLED is False)
156+
>>> numpy_array_predicate([1, 2, 3])
157+
False
158+
"""
159+
if not NUMPY_INSTALLED:
160+
return False
161+
162+
import numpy as np
163+
164+
return isinstance(value, np.ndarray)
165+
166+
167+
__all__ = ("from_json", "numpy_array_dec_hook", "numpy_array_enc_hook", "numpy_array_predicate", "to_json")

0 commit comments

Comments
 (0)