Skip to content

Commit 8c198a2

Browse files
Add helper function to pack dense vectors for efficient uploading
1 parent 5702501 commit 8c198a2

File tree

6 files changed

+57
-0
lines changed

6 files changed

+57
-0
lines changed

docs/sphinx/api_helpers.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ Bulk
1717
----
1818
.. autofunction:: bulk
1919

20+
Dense Vector packing
21+
--------------------
22+
.. autofunction:: pack_dense_vector
23+
2024
Scan
2125
----
2226
.. autofunction:: scan

elasticsearch/helpers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
BULK_FLUSH,
2424
bulk,
2525
expand_action,
26+
pack_dense_vector,
2627
parallel_bulk,
2728
reindex,
2829
scan,
@@ -37,6 +38,7 @@
3738
"expand_action",
3839
"streaming_bulk",
3940
"bulk",
41+
"pack_dense_vector",
4042
"parallel_bulk",
4143
"scan",
4244
"reindex",

elasticsearch/helpers/actions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import base64
1819
import logging
1920
import queue
2021
import time
@@ -31,6 +32,7 @@
3132
Mapping,
3233
MutableMapping,
3334
Optional,
35+
Sequence,
3436
Tuple,
3537
Union,
3638
)
@@ -708,6 +710,21 @@ def _setup_queues(self) -> None:
708710
pool.join()
709711

710712

713+
def pack_dense_vector(vector: Union["np.ndarray", Sequence[float]]) -> str:
    """Helper function that packs a dense vector for efficient uploading.

    The vector is serialized as a sequence of big-endian 32-bit floats and
    the resulting bytes are returned as a base64-encoded string.

    :arg vector: the list or numpy array to pack. A numpy array must already
        have a ``float32`` dtype; any other input is converted to ``float32``.
    :returns: the base64-encoded string with the packed vector.
    :raises ValueError: if a numpy array with a dtype other than ``float32``
        is given.
    """
    # numpy is imported lazily so that it is only needed when this helper
    # is actually called.
    import numpy as np

    # The exact type check is deliberate: anything that is not a plain
    # ndarray (lists, tuples, ndarray subclasses) is converted to a plain
    # float32 array, while plain ndarrays are never silently down-cast.
    if type(vector) is not np.ndarray:
        vector = np.array(vector, dtype=np.float32)
    elif vector.dtype != np.float32:
        raise ValueError("Only arrays of type float32 can be packed")

    # Request big-endian byte order explicitly instead of swapping the bytes
    # unconditionally: an unconditional byteswap() only yields big-endian
    # output on little-endian hosts and would produce little-endian bytes on
    # a big-endian host.
    byte_array = vector.astype(np.dtype(">f4"), copy=False).tobytes()
    return base64.b64encode(byte_array).decode()
726+
727+
711728
def scan(
712729
client: Elasticsearch,
713730
query: Optional[Any] = None,

examples/quotes/backend/quotes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from elasticsearch import NotFoundError, OrjsonSerializer
1313
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1414
from elasticsearch import dsl
15+
from elasticsearch.helpers import pack_dense_vector
1516

1617
model = SentenceTransformer("all-MiniLM-L6-v2")
1718
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())
@@ -33,6 +34,9 @@ class Config:
3334
class Index:
3435
name = 'quotes'
3536

37+
def clean(self):
38+
# pack the embedding for efficient uploading
39+
self.embedding = pack_dense_vector(self.embedding)
3640

3741
class Tag(BaseModel):
3842
tag: str

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from elasticsearch.dsl.query import Match
6060
from elasticsearch.dsl.types import MatchQuery
6161
from elasticsearch.dsl.utils import AttrList
62+
from elasticsearch.helpers import pack_dense_vector
6263
from elasticsearch.helpers.errors import BulkIndexError
6364

6465
snowball = analyzer("my_snow", tokenizer="standard", filter=["lowercase", "snowball"])
@@ -868,10 +869,19 @@ class Doc(AsyncDocument):
868869
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
869870
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
870871
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
872+
packed_float_vector: List[float] = mapped_field(DenseVector())
873+
packed_numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
871874

872875
class Index:
873876
name = "vectors"
874877

878+
def clean(self):
879+
# pack the dense vectors before they are sent to Elasticsearch
880+
self.packed_float_vector = pack_dense_vector(self.packed_float_vector)
881+
self.packed_numpy_float_vector = pack_dense_vector(
882+
self.packed_numpy_float_vector
883+
)
884+
875885
await Doc._index.delete(ignore_unavailable=True)
876886
await Doc.init()
877887

@@ -884,6 +894,8 @@ class Index:
884894
byte_vector=test_byte_vector,
885895
bit_vector=test_bit_vector,
886896
numpy_float_vector=np.array(test_float_vector),
897+
packed_float_vector=test_float_vector,
898+
packed_numpy_float_vector=np.array(test_float_vector, dtype=np.float32),
887899
)
888900
await doc.save(refresh=True)
889901

@@ -894,6 +906,9 @@ class Index:
894906
assert docs[0].bit_vector == test_bit_vector
895907
assert type(docs[0].numpy_float_vector) is np.ndarray
896908
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
909+
assert [round(v, 1) for v in docs[0].packed_float_vector] == test_float_vector
910+
assert type(docs[0].packed_numpy_float_vector) is np.ndarray
911+
assert [round(v, 1) for v in docs[0].packed_numpy_float_vector] == test_float_vector
897912

898913

899914
@pytest.mark.anyio

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from elasticsearch.dsl.query import Match
6060
from elasticsearch.dsl.types import MatchQuery
6161
from elasticsearch.dsl.utils import AttrList
62+
from elasticsearch.helpers import pack_dense_vector
6263
from elasticsearch.helpers.errors import BulkIndexError
6364

6465
snowball = analyzer("my_snow", tokenizer="standard", filter=["lowercase", "snowball"])
@@ -856,10 +857,19 @@ class Doc(Document):
856857
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
857858
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
858859
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
860+
packed_float_vector: List[float] = mapped_field(DenseVector())
861+
packed_numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
859862

860863
class Index:
861864
name = "vectors"
862865

866+
def clean(self):
867+
# pack the dense vectors before they are sent to Elasticsearch
868+
self.packed_float_vector = pack_dense_vector(self.packed_float_vector)
869+
self.packed_numpy_float_vector = pack_dense_vector(
870+
self.packed_numpy_float_vector
871+
)
872+
863873
Doc._index.delete(ignore_unavailable=True)
864874
Doc.init()
865875

@@ -872,6 +882,8 @@ class Index:
872882
byte_vector=test_byte_vector,
873883
bit_vector=test_bit_vector,
874884
numpy_float_vector=np.array(test_float_vector),
885+
packed_float_vector=test_float_vector,
886+
packed_numpy_float_vector=np.array(test_float_vector, dtype=np.float32),
875887
)
876888
doc.save(refresh=True)
877889

@@ -882,6 +894,9 @@ class Index:
882894
assert docs[0].bit_vector == test_bit_vector
883895
assert type(docs[0].numpy_float_vector) is np.ndarray
884896
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
897+
assert [round(v, 1) for v in docs[0].packed_float_vector] == test_float_vector
898+
assert type(docs[0].packed_numpy_float_vector) is np.ndarray
899+
assert [round(v, 1) for v in docs[0].packed_numpy_float_vector] == test_float_vector
885900

886901

887902
@pytest.mark.sync

0 commit comments

Comments
 (0)