
Commit 88a903a

Add helper function to pack dense vectors for efficient uploading
1 parent: 5702501

7 files changed: +202 −0 lines changed

docs/sphinx/api_helpers.rst

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@ Bulk
 ----
 .. autofunction:: bulk
 
+Dense Vector packing
+--------------------
+.. autofunction:: pack_dense_vector
+
 Scan
 ----
 .. autofunction:: scan

elasticsearch/helpers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     BULK_FLUSH,
     bulk,
     expand_action,
+    pack_dense_vector,
     parallel_bulk,
     reindex,
     scan,
@@ -37,6 +38,7 @@
     "expand_action",
     "streaming_bulk",
     "bulk",
+    "pack_dense_vector",
     "parallel_bulk",
     "scan",
     "reindex",

elasticsearch/helpers/actions.py

Lines changed: 21 additions & 0 deletions
@@ -15,12 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
 import logging
 import queue
 import time
 from enum import Enum
 from operator import methodcaller
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -31,6 +33,7 @@
     Mapping,
     MutableMapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
 )
@@ -43,6 +46,9 @@
 from ..serializer import Serializer
 from .errors import BulkIndexError, ScanError
 
+if TYPE_CHECKING:
+    import numpy as np
+
 logger = logging.getLogger("elasticsearch.helpers")
 
 
@@ -708,6 +714,21 @@ def _setup_queues(self) -> None:
         pool.join()
 
 
+def pack_dense_vector(vector: Union["np.ndarray", Sequence[float]]) -> str:
+    """Helper function that packs a dense vector for efficient uploading.
+
+    :arg vector: the list or numpy array to pack.
+    """
+    import numpy as np
+
+    if type(vector) is not np.ndarray:
+        vector = np.array(vector, dtype=np.float32)
+    elif vector.dtype != np.float32:
+        raise ValueError("Only arrays of type float32 can be packed")
+    byte_array = vector.byteswap().tobytes()
+    return base64.b64encode(byte_array).decode()
+
+
 def scan(
     client: Elasticsearch,
     query: Optional[Any] = None,
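The helper turns a float32 vector into a base64 string of its raw bytes; byteswap() flips the bytes out of the platform's native order, presumably to match the byte order the server expects when decoding the packed form. A minimal usage sketch follows; the client URL, index name, field names, and vector size are illustrative assumptions, not part of this commit:

import numpy as np

from elasticsearch import Elasticsearch
from elasticsearch.helpers import pack_dense_vector

client = Elasticsearch("http://localhost:9200")

# a float32 embedding, e.g. the output of a sentence-transformers model
embedding = np.random.rand(384).astype(np.float32)

# the embedding travels as a single compact base64 string rather than a
# JSON array of floats, which is smaller and cheaper to serialize
client.index(
    index="my-vectors",
    document={"text": "hello world", "emb": pack_dense_vector(embedding)},
)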

examples/quotes/backend/quotes.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 from elasticsearch import NotFoundError, OrjsonSerializer
 from elasticsearch.dsl.pydantic import AsyncBaseESModel
 from elasticsearch import dsl
+from elasticsearch.helpers import pack_dense_vector
 
 model = SentenceTransformer("all-MiniLM-L6-v2")
 dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())
@@ -33,6 +34,9 @@ class Config:
     class Index:
         name = 'quotes'
 
+    def clean(self):
+        # pack the embedding for efficient uploading
+        self.embedding = pack_dense_vector(self.embedding)
 
 class Tag(BaseModel):
     tag: str
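Design note: clean() is the DSL's per-document hook that, as I understand it, runs during validation when save() is called, so the embedding is converted to its packed form transparently just before the document is serialized; nothing else in the example app needs to change.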

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 15 additions & 0 deletions
@@ -59,6 +59,7 @@
 from elasticsearch.dsl.query import Match
 from elasticsearch.dsl.types import MatchQuery
 from elasticsearch.dsl.utils import AttrList
+from elasticsearch.helpers import pack_dense_vector
 from elasticsearch.helpers.errors import BulkIndexError
 
 snowball = analyzer("my_snow", tokenizer="standard", filter=["lowercase", "snowball"])
@@ -868,10 +869,19 @@ class Doc(AsyncDocument):
         byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
         bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
         numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
+        packed_float_vector: List[float] = mapped_field(DenseVector())
+        packed_numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
 
         class Index:
             name = "vectors"
 
+        def clean(self):
+            # pack the dense vectors before they are sent to Elasticsearch
+            self.packed_float_vector = pack_dense_vector(self.packed_float_vector)
+            self.packed_numpy_float_vector = pack_dense_vector(
+                self.packed_numpy_float_vector
+            )
+
     await Doc._index.delete(ignore_unavailable=True)
     await Doc.init()
 
@@ -884,6 +894,8 @@ class Index:
         byte_vector=test_byte_vector,
         bit_vector=test_bit_vector,
         numpy_float_vector=np.array(test_float_vector),
+        packed_float_vector=test_float_vector,
+        packed_numpy_float_vector=np.array(test_float_vector, dtype=np.float32),
     )
     await doc.save(refresh=True)
 
@@ -894,6 +906,9 @@ class Index:
     assert docs[0].bit_vector == test_bit_vector
     assert type(docs[0].numpy_float_vector) is np.ndarray
     assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
+    assert [round(v, 1) for v in docs[0].packed_float_vector] == test_float_vector
+    assert type(docs[0].packed_numpy_float_vector) is np.ndarray
+    assert [round(v, 1) for v in docs[0].packed_numpy_float_vector] == test_float_vector
 
 
 @pytest.mark.anyio

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 15 additions & 0 deletions
@@ -59,6 +59,7 @@
 from elasticsearch.dsl.query import Match
 from elasticsearch.dsl.types import MatchQuery
 from elasticsearch.dsl.utils import AttrList
+from elasticsearch.helpers import pack_dense_vector
 from elasticsearch.helpers.errors import BulkIndexError
 
 snowball = analyzer("my_snow", tokenizer="standard", filter=["lowercase", "snowball"])
@@ -856,10 +857,19 @@ class Doc(Document):
         byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
         bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
         numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
+        packed_float_vector: List[float] = mapped_field(DenseVector())
+        packed_numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())
 
         class Index:
             name = "vectors"
 
+        def clean(self):
+            # pack the dense vectors before they are sent to Elasticsearch
+            self.packed_float_vector = pack_dense_vector(self.packed_float_vector)
+            self.packed_numpy_float_vector = pack_dense_vector(
+                self.packed_numpy_float_vector
+            )
+
     Doc._index.delete(ignore_unavailable=True)
     Doc.init()
 
@@ -872,6 +882,8 @@ class Index:
         byte_vector=test_byte_vector,
         bit_vector=test_bit_vector,
         numpy_float_vector=np.array(test_float_vector),
+        packed_float_vector=test_float_vector,
+        packed_numpy_float_vector=np.array(test_float_vector, dtype=np.float32),
     )
     doc.save(refresh=True)
 
@@ -882,6 +894,9 @@ class Index:
     assert docs[0].bit_vector == test_bit_vector
    assert type(docs[0].numpy_float_vector) is np.ndarray
     assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector
+    assert [round(v, 1) for v in docs[0].packed_float_vector] == test_float_vector
+    assert type(docs[0].packed_numpy_float_vector) is np.ndarray
+    assert [round(v, 1) for v in docs[0].packed_numpy_float_vector] == test_float_vector
 
 
 @pytest.mark.sync

utils/dense-vector-benchmark.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import asyncio
+import json
+import os
+import time
+
+import numpy as np
+
+from elasticsearch import OrjsonSerializer
+from elasticsearch.dsl import AsyncDocument, NumpyDenseVector, async_connections
+from elasticsearch.dsl.types import DenseVectorIndexOptions
+from elasticsearch.helpers import async_bulk, pack_dense_vector
+
+async_connections.create_connection(
+    hosts=[os.environ["ELASTICSEARCH_URL"]], serializer=OrjsonSerializer()
+)
+
+
+class Doc(AsyncDocument):
+    title: str
+    text: str
+    emb: np.ndarray = NumpyDenseVector(
+        dtype=np.float32, index_options=DenseVectorIndexOptions(type="flat")
+    )
+
+    class Index:
+        name = "benchmark"
+
+
+async def upload(data_file: str, chunk_size: int, pack: bool) -> tuple[float, float]:
+    with open(data_file, "rt") as f:
+        # read the data file, which comes in ndjson format, and convert it to JSON
+        json_data = "[" + f.read().strip().replace("\n", ",") + "]"
+        dataset = json.loads(json_data)
+
+    # replace the embedding lists with numpy arrays for performance
+    dataset = [
+        {
+            "docid": doc["docid"],
+            "title": doc["title"],
+            "text": doc["text"],
+            "emb": np.array(doc["emb"], dtype=np.float32),
+        }
+        for doc in dataset
+    ]
+
+    # create mapping and index
+    if await Doc._index.exists():
+        await Doc._index.delete()
+    await Doc.init()
+    await Doc._index.refresh()
+
+    async def get_next_document():
+        for doc in dataset:
+            yield {
+                "_index": "benchmark",
+                "_id": doc["docid"],
+                "_source": {
+                    "title": doc["title"],
+                    "text": doc["text"],
+                    "emb": doc["emb"],
+                },
+            }
+
+    async def get_next_document_packed():
+        for doc in dataset:
+            yield {
+                "_index": "benchmark",
+                "_id": doc["docid"],
+                "_source": {
+                    "title": doc["title"],
+                    "text": doc["text"],
+                    "emb": pack_dense_vector(doc["emb"]),
+                },
+            }
+
+    start = time.time()
+    result = await async_bulk(
+        client=async_connections.get_connection(),
+        chunk_size=chunk_size,
+        actions=get_next_document_packed() if pack else get_next_document(),
+        stats_only=True,
+    )
+    duration = time.time() - start
+    assert result[1] == 0
+    return result[0], duration
+
+
+async def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("data_file", metavar="JSON_DATA_FILE")
+    parser.add_argument(
+        # type=int added here: without it argparse yields strings, and the bulk
+        # helper's count-based chunking expects an integer chunk size
+        "--chunk-sizes", "-s", type=int, nargs="+", help="Chunk size(s) for bulk uploader"
+    )
+    args = parser.parse_args()
+
+    for chunk_size in args.chunk_sizes:
+        print(f"Uploading '{args.data_file}' with chunk size {chunk_size}...")
+        runs = []
+        packed_runs = []
+        for _ in range(3):
+            runs.append(await upload(args.data_file, chunk_size, False))
+            packed_runs.append(await upload(args.data_file, chunk_size, True))
+
+        # ensure that all runs uploaded the same number of documents
+        size = runs[0][0]
+        for run in runs:
+            assert run[0] == size
+        for run in packed_runs:
+            assert run[0] == size
+
+        dur = sum([run[1] for run in runs]) / len(runs)
+        packed_dur = sum([run[1] for run in packed_runs]) / len(packed_runs)
+
+        print(f"Size: {size}")
+        print(f"float duration: {dur:.02f}s / {size / dur:.02f} docs/s")
+        print(
+            f"float base64 duration: {packed_dur:.02f}s / {size / packed_dur:.02f} docs/s"
+        )
+        print(f"Speed up: {dur / packed_dur:.02f}x")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
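The benchmark needs a running cluster (ELASTICSEARCH_URL) and an ndjson dataset whose documents carry docid, title, text, and emb fields. A typical invocation, with a hypothetical data file name, would be:

python utils/dense-vector-benchmark.py quotes.ndjson --chunk-sizes 250 500 1000

Each chunk size is uploaded three times with plain float lists and three times with packed vectors, and the averaged durations plus the resulting speed-up factor are printed.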
