Skip to content

Commit 66ae37d

Browse files
numpy deserialization
1 parent 195f2ce commit 66ae37d

File tree

4 files changed

+50
-8
lines changed

4 files changed

+50
-8
lines changed

elasticsearch/dsl/field.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,7 @@ class DenseVector(Field):
15551555
:arg dynamic:
15561556
:arg fields:
15571557
:arg synthetic_source_keep:
1558+
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
15581559
"""
15591560

15601561
name = "dense_vector"
@@ -1587,6 +1588,7 @@ def __init__(
15871588
synthetic_source_keep: Union[
15881589
Literal["none", "arrays", "all"], "DefaultType"
15891590
] = DEFAULT,
1591+
use_numpy: bool = False,
15901592
**kwargs: Any,
15911593
):
15921594
if dims is not DEFAULT:
@@ -1614,9 +1616,19 @@ def __init__(
16141616
self._element_type = kwargs.get("element_type", "float")
16151617
if self._element_type in ["float", "byte"]:
16161618
kwargs["multi"] = True
1619+
self._use_numpy = use_numpy
16171620
super().__init__(*args, **kwargs)
16181621

1622+
def deserialize(self, data: Any) -> Any:
    """Deserialize a raw field value, optionally as a numpy array.

    When the field was created with ``use_numpy=True`` and ``data`` is a
    ``list``, the list is converted with ``numpy.array``. Every other
    value is handed to the parent class unchanged.

    :arg data: the raw value to deserialize.
    :returns: a numpy array for list input when numpy deserialization is
        enabled, otherwise whatever the parent ``deserialize`` returns.
    """
    wants_array = self._use_numpy and isinstance(data, list)
    if not wants_array:
        return super().deserialize(data)

    # numpy is imported here, on the only code path that uses it
    import numpy as np

    return np.array(data)
1628+
16191629
def clean(self, data: Any) -> Any:
1630+
# this method does the same as the one in the parent class, but it
1631+
# avoids comparisons that break when data is a numpy array
16201632
if data is not None:
16211633
data = self.deserialize(data)
16221634
if (data is None or len(data) == 0) and self._required:

elasticsearch/dsl/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,17 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
612612
if skip_empty:
613613
# don't serialize empty values
614614
# careful not to include numeric zeros
615-
# the "is" operator is used below because it is the only comparison
616-
# that works for numpy arrays
617-
if v is [] or v is {} or v is None:
618-
continue
615+
try:
616+
if v in ([], {}, None):
617+
continue
618+
except ValueError:
619+
# the `in` test above compares with ==, and the result of that comparison can raise ValueError when v is a numpy array
620+
# try using len() instead
621+
try:
622+
if len(v) == 0:
623+
continue
624+
except TypeError:
625+
pass
619626

620627
out[k] = v
621628
return out

examples/quotes/backend/quotes.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,30 @@
55
from typing import Annotated
66

77
from fastapi import FastAPI, HTTPException
8-
from pydantic import BaseModel, Field, ValidationError
8+
import numpy as np
9+
from pydantic import BaseModel, Field, PlainSerializer
910
from sentence_transformers import SentenceTransformer
1011

11-
from elasticsearch import NotFoundError
12+
from elasticsearch import NotFoundError, OrjsonSerializer
1213
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1314
from elasticsearch import dsl
1415

1516
model = SentenceTransformer("all-MiniLM-L6-v2")
16-
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
17+
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())
1718

1819

1920
class Quote(AsyncBaseESModel):
2021
quote: str
2122
author: Annotated[str, dsl.Keyword()]
2223
tags: Annotated[list[str], dsl.Keyword()]
23-
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
24+
embedding: Annotated[
25+
np.ndarray,
26+
PlainSerializer(lambda v: v.tolist()),
27+
dsl.DenseVector(use_numpy=True)
28+
] = Field(init=False, default_factory=lambda: np.array([]))
29+
30+
class Config:
31+
arbitrary_types_allowed = True
2432

2533
class Index:
2634
name = 'quotes'

utils/templates/field.py.tpl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ class {{ k.name }}({{ k.parent }}):
217217
{% endfor %}
218218
{% endfor %}
219219
{% endif %}
220+
{% if k.field == "dense_vector" %}
221+
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
222+
{% endif %}
220223
"""
221224
name = "{{ k.field }}"
222225
{% if k.coerced %}
@@ -246,6 +249,9 @@ class {{ k.name }}({{ k.parent }}):
246249
{{ arg.name }}: {{ arg.type }} = DEFAULT,
247250
{% endif %}
248251
{% endfor %}
252+
{% if k.field == "dense_vector" %}
253+
use_numpy: bool = False,
254+
{% endif %}
249255
**kwargs: Any
250256
):
251257
{% for arg in k.args %}
@@ -416,9 +422,18 @@ class {{ k.name }}({{ k.parent }}):
416422
self._element_type = kwargs.get("element_type", "float")
417423
if self._element_type in ["float", "byte"]:
418424
kwargs["multi"] = True
425+
self._use_numpy = use_numpy
419426
super().__init__(*args, **kwargs)
420427

428+
def deserialize(self, data: Any) -> Any:
    """Deserialize a raw field value, optionally as a numpy array.

    When the field was created with ``use_numpy=True`` and ``data`` is a
    ``list``, the list is converted with ``numpy.array``. Every other
    value is handed to the parent class unchanged.

    :arg data: the raw value to deserialize.
    :returns: a numpy array for list input when numpy deserialization is
        enabled, otherwise whatever the parent ``deserialize`` returns.
    """
    wants_array = self._use_numpy and isinstance(data, list)
    if not wants_array:
        return super().deserialize(data)

    # numpy is imported here, on the only code path that uses it
    import numpy as np

    return np.array(data)
433+
421434
def clean(self, data: Any) -> Any:
435+
# this method does the same as the one in the parent class, but it
436+
# avoids comparisons that break when data is a numpy array
422437
if data is not None:
423438
data = self.deserialize(data)
424439
if (data is None or len(data) == 0) and self._required:

0 commit comments

Comments
 (0)