Skip to content

Commit dc9c7c7

Browse files
Support dense vectors with base64 encoding
1 parent bbca81a commit dc9c7c7

File tree

7 files changed

+109
-29
lines changed

7 files changed

+109
-29
lines changed

elasticsearch/dsl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
MatchOnlyText,
7474
Murmur3,
7575
Nested,
76+
NumpyDenseVector,
7677
Object,
7778
Passthrough,
7879
Percolator,
@@ -189,6 +190,7 @@
189190
"Murmur3",
190191
"Nested",
191192
"NestedFacet",
193+
"NumpyDenseVector",
192194
"Object",
193195
"Passthrough",
194196
"Percolator",

elasticsearch/dsl/field.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,11 +1616,33 @@ def __init__(
16161616
kwargs["multi"] = True
16171617
super().__init__(*args, **kwargs)
16181618

1619-
def _deserialize(self, data: Any) -> Any:
1620-
if self._element_type == "float":
1621-
return float(data)
1622-
elif self._element_type == "byte":
1623-
return int(data)
1619+
1620+
class NumpyDenseVector(DenseVector):
1621+
"""A dense vector field that uses numpy arrays.
1622+
1623+
Accepts the same arguments as class ``DenseVector`` plus:
1624+
1625+
:arg dtype: The numpy data type to use for the array. If not given, numpy will select the type based on the data.
1626+
"""
1627+
1628+
def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
1629+
super().__init__(*args, **kwargs)
1630+
self._dtype = dtype
1631+
1632+
def deserialize(self, data: Any) -> Any:
1633+
if isinstance(data, list):
1634+
import numpy as np
1635+
1636+
return np.array(data, dtype=self._dtype)
1637+
return super().deserialize(data)
1638+
1639+
def clean(self, data: Any) -> Any:
1640+
# this method does the same as the one in the parent classes, but it
1641+
# avoids comparisons that do not work for numpy arrays
1642+
if data is not None:
1643+
data = self.deserialize(data)
1644+
if (data is None or len(data) == 0) and self._required:
1645+
raise ValidationException("Value required for this field.")
16241646
return data
16251647

16261648

elasticsearch/dsl/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,8 +612,17 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
612612
if skip_empty:
613613
# don't serialize empty values
614614
# careful not to include numeric zeros
615-
if v in ([], {}, None):
616-
continue
615+
try:
616+
if v in ([], {}, None):
617+
continue
618+
except ValueError:
619+
# the above fails when v is a numpy array
620+
# try using len() instead
621+
try:
622+
if len(v) == 0:
623+
continue
624+
except TypeError:
625+
pass
617626

618627
out[k] = v
619628
return out

examples/quotes/backend/quotes.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,30 @@
55
from typing import Annotated
66

77
from fastapi import FastAPI, HTTPException
8-
from pydantic import BaseModel, Field, ValidationError
8+
import numpy as np
9+
from pydantic import BaseModel, Field, PlainSerializer
910
from sentence_transformers import SentenceTransformer
1011

11-
from elasticsearch import NotFoundError
12+
from elasticsearch import NotFoundError, OrjsonSerializer
1213
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1314
from elasticsearch import dsl
1415

1516
model = SentenceTransformer("all-MiniLM-L6-v2")
16-
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
17+
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())
1718

1819

1920
class Quote(AsyncBaseESModel):
2021
quote: str
2122
author: Annotated[str, dsl.Keyword()]
2223
tags: Annotated[list[str], dsl.Keyword()]
23-
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
24+
embedding: Annotated[
25+
np.ndarray,
26+
PlainSerializer(lambda v: v.tolist()),
27+
dsl.NumpyDenseVector(dtype=np.float32)
28+
] = Field(init=False, default_factory=lambda: np.array([], dtype=np.float32))
29+
30+
class Config:
31+
arbitrary_types_allowed = True
2432

2533
class Index:
2634
name = 'quotes'
@@ -135,7 +143,7 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
135143
def embed_quotes(quotes):
136144
embeddings = model.encode([q.quote for q in quotes])
137145
for q, e in zip(quotes, embeddings):
138-
q.embedding = e.tolist()
146+
q.embedding = e
139147

140148

141149
async def ingest_quotes():

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ipaddress import ip_address
2626
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple, Union
2727

28+
import numpy as np
2829
import pytest
2930
from pytest import raises
3031
from pytz import timezone
@@ -47,6 +48,7 @@
4748
Mapping,
4849
MetaField,
4950
Nested,
51+
NumpyDenseVector,
5052
Object,
5153
Q,
5254
RankFeatures,
@@ -865,25 +867,33 @@ class Doc(AsyncDocument):
865867
float_vector: List[float] = mapped_field(DenseVector())
866868
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
867869
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
870+
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
868871

869872
class Index:
870873
name = "vectors"
871874

872875
await Doc._index.delete(ignore_unavailable=True)
873876
await Doc.init()
874877

878+
test_float_vector = [1.0, 1.2, 2.3]
879+
test_byte_vector = [12, 23, 34, 45]
880+
test_bit_vector = [18, -43, -112]
881+
875882
doc = Doc(
876-
float_vector=[1.0, 1.2, 2.3],
877-
byte_vector=[12, 23, 34, 45],
878-
bit_vector=[18, -43, -112],
883+
float_vector=test_float_vector,
884+
byte_vector=test_byte_vector,
885+
bit_vector=test_bit_vector,
886+
numpy_float_vector=np.array(test_float_vector),
879887
)
880888
await doc.save(refresh=True)
881889

882890
docs = await Doc.search().execute()
883891
assert len(docs) == 1
884-
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
885-
assert docs[0].byte_vector == doc.byte_vector
886-
assert docs[0].bit_vector == doc.bit_vector
892+
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
893+
assert docs[0].byte_vector == test_byte_vector
894+
assert docs[0].bit_vector == test_bit_vector
895+
assert type(docs[0].numpy_float_vector) is np.ndarray
896+
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
887897

888898

889899
@pytest.mark.anyio

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ipaddress import ip_address
2626
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
2727

28+
import numpy as np
2829
import pytest
2930
from pytest import raises
3031
from pytz import timezone
@@ -46,6 +47,7 @@
4647
Mapping,
4748
MetaField,
4849
Nested,
50+
NumpyDenseVector,
4951
Object,
5052
Q,
5153
RankFeatures,
@@ -853,25 +855,33 @@ class Doc(Document):
853855
float_vector: List[float] = mapped_field(DenseVector())
854856
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
855857
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
858+
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
856859

857860
class Index:
858861
name = "vectors"
859862

860863
Doc._index.delete(ignore_unavailable=True)
861864
Doc.init()
862865

866+
test_float_vector = [1.0, 1.2, 2.3]
867+
test_byte_vector = [12, 23, 34, 45]
868+
test_bit_vector = [18, -43, -112]
869+
863870
doc = Doc(
864-
float_vector=[1.0, 1.2, 2.3],
865-
byte_vector=[12, 23, 34, 45],
866-
bit_vector=[18, -43, -112],
871+
float_vector=test_float_vector,
872+
byte_vector=test_byte_vector,
873+
bit_vector=test_bit_vector,
874+
numpy_float_vector=np.array(test_float_vector),
867875
)
868876
doc.save(refresh=True)
869877

870878
docs = Doc.search().execute()
871879
assert len(docs) == 1
872-
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
873-
assert docs[0].byte_vector == doc.byte_vector
874-
assert docs[0].bit_vector == doc.bit_vector
880+
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
881+
assert docs[0].byte_vector == test_byte_vector
882+
assert docs[0].bit_vector == test_bit_vector
883+
assert type(docs[0].numpy_float_vector) is np.ndarray
884+
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
875885

876886

877887
@pytest.mark.sync

utils/templates/field.py.tpl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,30 @@ class {{ k.name }}({{ k.parent }}):
418418
kwargs["multi"] = True
419419
super().__init__(*args, **kwargs)
420420

421-
def _deserialize(self, data: Any) -> Any:
422-
if self._element_type == "float":
423-
return float(data)
424-
elif self._element_type == "byte":
425-
return int(data)
421+
class NumpyDenseVector(DenseVector):
422+
"""A dense vector field that uses numpy arrays.
423+
424+
Accepts the same arguments as class ``DenseVector`` plus:
425+
426+
:arg dtype: The numpy data type to use for the array. If not given, numpy will select the type based on the data.
427+
"""
428+
def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
429+
super().__init__(*args, **kwargs)
430+
self._dtype = dtype
431+
432+
def deserialize(self, data: Any) -> Any:
433+
if isinstance(data, list):
434+
import numpy as np
435+
return np.array(data, dtype=self._dtype)
436+
return super().deserialize(data)
437+
438+
def clean(self, data: Any) -> Any:
439+
# this method does the same as the one in the parent classes, but it
440+
# avoids comparisons that do not work for numpy arrays
441+
if data is not None:
442+
data = self.deserialize(data)
443+
if (data is None or len(data) == 0) and self._required:
444+
raise ValidationException("Value required for this field.")
426445
return data
427446
{% elif k.field == "scaled_float" %}
428447
if 'scaling_factor' not in kwargs:

0 commit comments

Comments
 (0)