Skip to content

Commit 3d3c790

Browse files
numpy deserialization
1 parent 195f2ce commit 3d3c790

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

elasticsearch/dsl/field.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1555,6 +1555,7 @@ class DenseVector(Field):
15551555
:arg dynamic:
15561556
:arg fields:
15571557
:arg synthetic_source_keep:
1558+
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
15581559
"""
15591560

15601561
name = "dense_vector"
@@ -1587,6 +1588,7 @@ def __init__(
15871588
synthetic_source_keep: Union[
15881589
Literal["none", "arrays", "all"], "DefaultType"
15891590
] = DEFAULT,
1591+
use_numpy: bool = False,
15901592
**kwargs: Any,
15911593
):
15921594
if dims is not DEFAULT:
@@ -1614,6 +1616,7 @@ def __init__(
16141616
self._element_type = kwargs.get("element_type", "float")
16151617
if self._element_type in ["float", "byte"]:
16161618
kwargs["multi"] = True
1619+
self._use_numpy = use_numpy
16171620
super().__init__(*args, **kwargs)
16181621

16191622
def clean(self, data: Any) -> Any:
@@ -1623,6 +1626,13 @@ def clean(self, data: Any) -> Any:
16231626
raise ValidationException("Value required for this field.")
16241627
return data
16251628

1629+
def deserialize(self, data: Any) -> Any:
1630+
if self._use_numpy and isinstance(data, list):
1631+
import numpy as np
1632+
1633+
return np.array(data)
1634+
return super().deserialize(data)
1635+
16261636

16271637
class Double(Float):
16281638
"""

examples/quotes/backend/quotes.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,22 +5,30 @@
55
from typing import Annotated
66

77
from fastapi import FastAPI, HTTPException
8-
from pydantic import BaseModel, Field, ValidationError
8+
import numpy as np
9+
from pydantic import BaseModel, Field, PlainSerializer
910
from sentence_transformers import SentenceTransformer
1011

11-
from elasticsearch import NotFoundError
12+
from elasticsearch import NotFoundError, OrjsonSerializer
1213
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1314
from elasticsearch import dsl
1415

1516
model = SentenceTransformer("all-MiniLM-L6-v2")
16-
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
17+
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())
1718

1819

1920
class Quote(AsyncBaseESModel):
2021
quote: str
2122
author: Annotated[str, dsl.Keyword()]
2223
tags: Annotated[list[str], dsl.Keyword()]
23-
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
24+
embedding: Annotated[
25+
np.ndarray,
26+
PlainSerializer(lambda v: v.tolist()),
27+
dsl.DenseVector(use_numpy=True)
28+
] = Field(init=False, default_factory=lambda: np.array([]))
29+
30+
class Config:
31+
arbitrary_types_allowed = True
2432

2533
class Index:
2634
name = 'quotes'

utils/templates/field.py.tpl

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,9 @@ class {{ k.name }}({{ k.parent }}):
217217
{% endfor %}
218218
{% endfor %}
219219
{% endif %}
220+
{% if k.field == "dense_vector" %}
221+
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
222+
{% endif %}
220223
"""
221224
name = "{{ k.field }}"
222225
{% if k.coerced %}
@@ -246,6 +249,9 @@ class {{ k.name }}({{ k.parent }}):
246249
{{ arg.name }}: {{ arg.type }} = DEFAULT,
247250
{% endif %}
248251
{% endfor %}
252+
{% if k.field == "dense_vector" %}
253+
use_numpy: bool = False,
254+
{% endif %}
249255
**kwargs: Any
250256
):
251257
{% for arg in k.args %}
@@ -416,6 +422,7 @@ class {{ k.name }}({{ k.parent }}):
416422
self._element_type = kwargs.get("element_type", "float")
417423
if self._element_type in ["float", "byte"]:
418424
kwargs["multi"] = True
425+
self._use_numpy = use_numpy
419426
super().__init__(*args, **kwargs)
420427

421428
def clean(self, data: Any) -> Any:
@@ -424,6 +431,12 @@ class {{ k.name }}({{ k.parent }}):
424431
if (data is None or len(data) == 0) and self._required:
425432
raise ValidationException("Value required for this field.")
426433
return data
434+
435+
def deserialize(self, data: Any) -> Any:
436+
if self._use_numpy and isinstance(data, list):
437+
import numpy as np
438+
return np.array(data)
439+
return super().deserialize(data)
427440
{% elif k.field == "scaled_float" %}
428441
if 'scaling_factor' not in kwargs:
429442
if len(args) > 0:

0 commit comments

Comments
 (0)