Skip to content

Commit 9b21ff7

Browse files
support new dense vector quantization in 8.16
1 parent 0dd69f8 commit 9b21ff7

File tree

3 files changed

+92
-3
lines changed

3 files changed

+92
-3
lines changed

elasticsearch_dsl/field.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
389389
return float(data)
390390

391391

392-
class DenseVector(Float):
392+
class DenseVector(Field):
393393
name = "dense_vector"
394+
_coerce = True
394395

395396
def __init__(self, **kwargs: Any):
396-
kwargs["multi"] = True
397+
self._element_type = kwargs.get("element_type", "float")
398+
if self._element_type in ["float", "byte"]:
399+
kwargs["multi"] = True
397400
super().__init__(**kwargs)
398401

402+
def _deserialize(self, data: Any) -> Any:
403+
if self._element_type == "float":
404+
return float(data)
405+
elif self._element_type == "byte":
406+
return int(data)
407+
return data
408+
399409

400410
class SparseVector(Field):
401411
name = "sparse_vector"

tests/test_integration/_async/test_document.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from datetime import datetime
2525
from ipaddress import ip_address
26-
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
26+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union
2727

2828
import pytest
2929
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
@@ -37,6 +37,7 @@
3737
Binary,
3838
Boolean,
3939
Date,
40+
DenseVector,
4041
Double,
4142
InnerDoc,
4243
Ip,
@@ -795,3 +796,55 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
795796
"age": 45,
796797
"languages": ["es"],
797798
}
799+
800+
801+
@pytest.mark.asyncio
802+
async def test_float_dense_vector(async_client: AsyncElasticsearch) -> None:
803+
if es_version >= (8, 16):
804+
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
805+
806+
class Doc(AsyncDocument):
807+
float_vector: List[float] = mapped_field(DenseVector())
808+
809+
class Index:
810+
name = "vectors"
811+
812+
await Doc._index.delete(ignore_unavailable=True)
813+
await Doc.init()
814+
815+
doc = Doc(
816+
float_vector=[1.0, 1.2, 2.3]
817+
)
818+
await doc.save(refresh=True)
819+
820+
docs = await Doc.search().execute()
821+
assert len(docs) == 1
822+
assert docs[0].float_vector == doc.float_vector
823+
824+
825+
@pytest.mark.asyncio
826+
async def test_dense_vector(async_client: AsyncElasticsearch, es_version: Tuple[int, ...]) -> None:
827+
if es_version < (8, 16):
828+
pytest.skip("this test requires Elasticsearch 8.16 or newer")
829+
830+
class Doc(AsyncDocument):
831+
float_vector: List[float] = mapped_field(DenseVector())
832+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
833+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
834+
835+
class Index:
836+
name = "vectors"
837+
838+
await Doc._index.delete(ignore_unavailable=True)
839+
await Doc.init()
840+
841+
doc = Doc(
842+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
843+
)
844+
await doc.save(refresh=True)
845+
846+
docs = await Doc.search().execute()
847+
assert len(docs) == 1
848+
assert docs[0].float_vector == doc.float_vector
849+
assert docs[0].byte_vector == doc.byte_vector
850+
assert docs[0].bit_vector == doc.bit_vector

tests/test_integration/_sync/test_document.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Binary,
3636
Boolean,
3737
Date,
38+
DenseVector,
3839
Document,
3940
Double,
4041
InnerDoc,
@@ -789,3 +790,28 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
789790
"age": 45,
790791
"languages": ["es"],
791792
}
793+
794+
795+
@pytest.mark.sync
796+
def test_dense_vector_quantization(client: Elasticsearch) -> None:
797+
class Doc(Document):
798+
float_vector: List[float] = mapped_field(DenseVector())
799+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
800+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
801+
802+
class Index:
803+
name = "vectors"
804+
805+
Doc._index.delete(ignore_unavailable=True)
806+
Doc.init()
807+
808+
doc = Doc(
809+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
810+
)
811+
doc.save(refresh=True)
812+
813+
docs = Doc.search().execute()
814+
assert len(docs) == 1
815+
assert docs[0].float_vector == doc.float_vector
816+
assert docs[0].byte_vector == doc.byte_vector
817+
assert docs[0].bit_vector == doc.bit_vector

0 commit comments

Comments
 (0)