Skip to content

Commit 48d7f9b

Browse files
support new dense vector quantization in 8.16
1 parent 0dd69f8 commit 48d7f9b

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

elasticsearch_dsl/field.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
389389
return float(data)
390390

391391

392-
class DenseVector(Float):
392+
class DenseVector(Field):
393393
name = "dense_vector"
394+
_coerce = True
394395

395396
def __init__(self, **kwargs: Any):
396-
kwargs["multi"] = True
397+
self._element_type = kwargs.get("element_type", "float")
398+
if self._element_type in ["float", "byte"]:
399+
kwargs["multi"] = True
397400
super().__init__(**kwargs)
398401

402+
def _deserialize(self, data: Any) -> float:
403+
if self._element_type == "float":
404+
return float(data)
405+
elif self._element_type == "byte":
406+
return int(data)
407+
return data
408+
399409

400410
class SparseVector(Field):
401411
name = "sparse_vector"

tests/test_integration/_async/test_document.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
Binary,
3838
Boolean,
3939
Date,
40+
DenseVector,
4041
Double,
4142
InnerDoc,
4243
Ip,
@@ -795,3 +796,28 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
795796
"age": 45,
796797
"languages": ["es"],
797798
}
799+
800+
801+
@pytest.mark.asyncio
802+
async def test_dense_vector_quantization(async_client: AsyncElasticsearch) -> None:
803+
class Doc(AsyncDocument):
804+
float_vector: List[float] = mapped_field(DenseVector())
805+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
806+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
807+
808+
class Index:
809+
name = "vectors"
810+
811+
await Doc._index.delete(ignore_unavailable=True)
812+
await Doc.init()
813+
814+
doc = Doc(
815+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
816+
)
817+
await doc.save(refresh=True)
818+
819+
docs = await Doc.search().execute()
820+
assert len(docs) == 1
821+
assert docs[0].float_vector == doc.float_vector
822+
assert docs[0].byte_vector == doc.byte_vector
823+
assert docs[0].bit_vector == doc.bit_vector

tests/test_integration/_sync/test_document.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
Binary,
3636
Boolean,
3737
Date,
38+
DenseVector,
3839
Document,
3940
Double,
4041
InnerDoc,
@@ -789,3 +790,28 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
789790
"age": 45,
790791
"languages": ["es"],
791792
}
793+
794+
795+
@pytest.mark.sync
796+
def test_dense_vector_quantization(client: Elasticsearch) -> None:
797+
class Doc(Document):
798+
float_vector: List[float] = mapped_field(DenseVector())
799+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
800+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
801+
802+
class Index:
803+
name = "vectors"
804+
805+
Doc._index.delete(ignore_unavailable=True)
806+
Doc.init()
807+
808+
doc = Doc(
809+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
810+
)
811+
doc.save(refresh=True)
812+
813+
docs = Doc.search().execute()
814+
assert len(docs) == 1
815+
assert docs[0].float_vector == doc.float_vector
816+
assert docs[0].byte_vector == doc.byte_vector
817+
assert docs[0].bit_vector == doc.bit_vector

0 commit comments

Comments
 (0)